153022affSStefano Zampini #include <petscmat.h> 253022affSStefano Zampini #include <h2opus.h> 353022affSStefano Zampini 453022affSStefano Zampini #ifndef __MATH2OPUS_HPP 553022affSStefano Zampini #define __MATH2OPUS_HPP 653022affSStefano Zampini 753022affSStefano Zampini class PetscMatrixSampler : public HMatrixSampler 853022affSStefano Zampini { 953022affSStefano Zampini protected: 1053022affSStefano Zampini Mat A; 1153022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector; 1253022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector; 1353022affSStefano Zampini HIntVector hindexmap; 1453022affSStefano Zampini HRealVector hbuffer_in,hbuffer_out; 1553022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1653022affSStefano Zampini H2OpusDeviceVector<int> dindexmap; 1753022affSStefano Zampini H2OpusDeviceVector<H2Opus_Real> dbuffer_in,dbuffer_out; 1853022affSStefano Zampini #endif 1953022affSStefano Zampini bool gpusampling; 2053022affSStefano Zampini h2opusComputeStream_t stream; 2153022affSStefano Zampini 2253022affSStefano Zampini private: 2353022affSStefano Zampini void Init(); 2453022affSStefano Zampini void VerifyBuffers(int); 2553022affSStefano Zampini void PermuteBuffersIn(int,H2Opus_Real*,H2Opus_Real**,H2Opus_Real*,H2Opus_Real**); 2653022affSStefano Zampini void PermuteBuffersOut(int,H2Opus_Real*); 2753022affSStefano Zampini 2853022affSStefano Zampini public: 2953022affSStefano Zampini PetscMatrixSampler(); 3053022affSStefano Zampini PetscMatrixSampler(Mat); 3153022affSStefano Zampini ~PetscMatrixSampler(); 3253022affSStefano Zampini void SetSamplingMat(Mat); 3353022affSStefano Zampini void SetIndexMap(int,int*); 3453022affSStefano Zampini void SetGPUSampling(bool); 3553022affSStefano Zampini void SetStream(h2opusComputeStream_t); 3653022affSStefano Zampini virtual void sample(H2Opus_Real*,H2Opus_Real*,int); 3753022affSStefano Zampini Mat GetSamplingMat() { return A; } 3853022affSStefano Zampini }; 3953022affSStefano Zampini 4053022affSStefano Zampini void PetscMatrixSampler::Init() 4153022affSStefano Zampini { 4253022affSStefano Zampini this->A = NULL; 4353022affSStefano Zampini this->gpusampling = false; 4453022affSStefano Zampini this->stream = NULL; 4553022affSStefano Zampini } 4653022affSStefano Zampini 4753022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler() 4853022affSStefano Zampini { 4953022affSStefano Zampini Init(); 5053022affSStefano Zampini } 5153022affSStefano Zampini 5253022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler(Mat A) 5353022affSStefano Zampini { 5453022affSStefano Zampini Init(); 5553022affSStefano Zampini SetSamplingMat(A); 5653022affSStefano Zampini } 5753022affSStefano Zampini 5853022affSStefano Zampini void PetscMatrixSampler::SetSamplingMat(Mat A) 5953022affSStefano Zampini { 60300d917bSStefano Zampini PetscMPIInt size = 1; 6153022affSStefano Zampini 62*9566063dSJacob Faibussowitsch if (A) PetscCallVoid(MPI_Comm_size(PetscObjectComm((PetscObject)A),&size)); 63*9566063dSJacob Faibussowitsch if (size > 1) PetscCallVoid(PETSC_ERR_SUP); 64*9566063dSJacob Faibussowitsch PetscCallVoid(PetscObjectReference((PetscObject)A)); 65*9566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&this->A)); 6653022affSStefano Zampini this->A = A; 6753022affSStefano Zampini } 6853022affSStefano Zampini 6953022affSStefano Zampini void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) 7053022affSStefano Zampini { 7153022affSStefano Zampini this->stream = stream; 7253022affSStefano Zampini } 7353022affSStefano Zampini 7453022affSStefano Zampini void PetscMatrixSampler::SetIndexMap(int n,int *indexmap) 7553022affSStefano Zampini { 7653022affSStefano Zampini copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7753022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 7853022affSStefano Zampini copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7953022affSStefano Zampini #endif 8053022affSStefano Zampini } 8153022affSStefano Zampini 8253022affSStefano Zampini void PetscMatrixSampler::VerifyBuffers(int nv) 8353022affSStefano Zampini { 8453022affSStefano Zampini if (this->hindexmap.size()) { 8553022affSStefano Zampini size_t n = this->hindexmap.size(); 8653022affSStefano Zampini if (!this->gpusampling) { 8753022affSStefano Zampini if (hbuffer_in.size() < (size_t)n * nv) 8853022affSStefano Zampini hbuffer_in.resize(n * nv); 8953022affSStefano Zampini if (hbuffer_out.size() < (size_t)n * nv) 9053022affSStefano Zampini hbuffer_out.resize(n * nv); 9153022affSStefano Zampini } else { 9253022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 9353022affSStefano Zampini if (dbuffer_in.size() < (size_t)n * nv) 9453022affSStefano Zampini dbuffer_in.resize(n * nv); 9553022affSStefano Zampini if (dbuffer_out.size() < (size_t)n * nv) 9653022affSStefano Zampini dbuffer_out.resize(n * nv); 9753022affSStefano Zampini #endif 9853022affSStefano Zampini } 9953022affSStefano Zampini } 10053022affSStefano Zampini } 10153022affSStefano Zampini 10253022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) 10353022affSStefano Zampini { 10453022affSStefano Zampini *w = v; 10553022affSStefano Zampini *ow = ov; 10653022affSStefano Zampini VerifyBuffers(nv); 10753022affSStefano Zampini if (this->hindexmap.size()) { 10853022affSStefano Zampini size_t n = this->hindexmap.size(); 10953022affSStefano Zampini if (!this->gpusampling) { 11053022affSStefano Zampini permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, 11153022affSStefano Zampini this->stream); 11253022affSStefano Zampini *w = this->hbuffer_in.data(); 11353022affSStefano Zampini *ow = this->hbuffer_out.data(); 11453022affSStefano Zampini } else { 11553022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 11653022affSStefano Zampini permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, 11753022affSStefano Zampini this->stream); 11853022affSStefano Zampini *w = this->dbuffer_in.data(); 11953022affSStefano Zampini *ow = this->dbuffer_out.data(); 12053022affSStefano Zampini #endif 12153022affSStefano Zampini } 12253022affSStefano Zampini } 12353022affSStefano Zampini } 12453022affSStefano Zampini 12553022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) 12653022affSStefano Zampini { 12753022affSStefano Zampini VerifyBuffers(nv); 12853022affSStefano Zampini if (this->hindexmap.size()) { 12953022affSStefano Zampini size_t n = this->hindexmap.size(); 13053022affSStefano Zampini if (!this->gpusampling) { 13153022affSStefano Zampini permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, 13253022affSStefano Zampini this->stream); 13353022affSStefano Zampini } else { 13453022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 13553022affSStefano Zampini permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, 13653022affSStefano Zampini this->stream); 13753022affSStefano Zampini #endif 13853022affSStefano Zampini } 13953022affSStefano Zampini } 14053022affSStefano Zampini } 14153022affSStefano Zampini 14253022affSStefano Zampini void PetscMatrixSampler::SetGPUSampling(bool gpusampling) 14353022affSStefano Zampini { 14453022affSStefano Zampini this->gpusampling = gpusampling; 14553022affSStefano Zampini } 14653022affSStefano Zampini 14753022affSStefano Zampini PetscMatrixSampler::~PetscMatrixSampler() 14853022affSStefano Zampini { 149*9566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&A)); 15053022affSStefano Zampini } 15153022affSStefano Zampini 15253022affSStefano Zampini void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) 15353022affSStefano Zampini { 15453022affSStefano Zampini MPI_Comm comm = PetscObjectComm((PetscObject)this->A); 15553022affSStefano Zampini Mat X = NULL,Y = NULL; 15653022affSStefano Zampini PetscInt M,N,m,n; 15753022affSStefano Zampini H2Opus_Real *px,*py; 15853022affSStefano Zampini 159*9566063dSJacob Faibussowitsch if (!this->A) PetscCallVoid(PETSC_ERR_PLIB); 160*9566063dSJacob Faibussowitsch PetscCallVoid(MatGetSize(this->A,&M,&N)); 161*9566063dSJacob Faibussowitsch PetscCallVoid(MatGetLocalSize(this->A,&m,&n)); 162*9566063dSJacob Faibussowitsch PetscCallVoid(PetscObjectGetComm((PetscObject)A,&comm)); 16353022affSStefano Zampini PermuteBuffersIn(samples,x,&px,y,&py); 16453022affSStefano Zampini if (!this->gpusampling) { 165*9566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDense(comm,n,PETSC_DECIDE,N,samples,px,&X)); 166*9566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDense(comm,m,PETSC_DECIDE,M,samples,py,&Y)); 16753022affSStefano Zampini } else { 16853022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 169*9566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDenseCUDA(comm,n,PETSC_DECIDE,N,samples,px,&X)); 170*9566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDenseCUDA(comm,m,PETSC_DECIDE,M,samples,py,&Y)); 17153022affSStefano Zampini #endif 17253022affSStefano Zampini } 173*9566063dSJacob Faibussowitsch PetscCallVoid(PetscLogObjectParent((PetscObject)this->A,(PetscObject)X)); 174*9566063dSJacob Faibussowitsch PetscCallVoid(PetscLogObjectParent((PetscObject)this->A,(PetscObject)Y)); 175*9566063dSJacob Faibussowitsch PetscCallVoid(MatMatMult(this->A,X,MAT_REUSE_MATRIX,PETSC_DEFAULT,&Y)); 17653022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 17753022affSStefano Zampini if (this->gpusampling) { 17853022affSStefano Zampini const PetscScalar *dummy; 179*9566063dSJacob Faibussowitsch PetscCallVoid(MatDenseCUDAGetArrayRead(Y,&dummy)); 180*9566063dSJacob Faibussowitsch PetscCallVoid(MatDenseCUDARestoreArrayRead(Y,&dummy)); 18153022affSStefano Zampini } 18253022affSStefano Zampini #endif 18353022affSStefano Zampini PermuteBuffersOut(samples,y); 184*9566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&X)); 185*9566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&Y)); 18653022affSStefano Zampini } 18753022affSStefano Zampini 18853022affSStefano Zampini #endif 189