1*53022affSStefano Zampini #include <petscmat.h> 2*53022affSStefano Zampini #include <h2opus.h> 3*53022affSStefano Zampini 4*53022affSStefano Zampini #ifndef __MATH2OPUS_HPP 5*53022affSStefano Zampini #define __MATH2OPUS_HPP 6*53022affSStefano Zampini 7*53022affSStefano Zampini class PetscMatrixSampler : public HMatrixSampler 8*53022affSStefano Zampini { 9*53022affSStefano Zampini protected: 10*53022affSStefano Zampini Mat A; 11*53022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector; 12*53022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector; 13*53022affSStefano Zampini HIntVector hindexmap; 14*53022affSStefano Zampini HRealVector hbuffer_in,hbuffer_out; 15*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 16*53022affSStefano Zampini H2OpusDeviceVector<int> dindexmap; 17*53022affSStefano Zampini H2OpusDeviceVector<H2Opus_Real> dbuffer_in,dbuffer_out; 18*53022affSStefano Zampini #endif 19*53022affSStefano Zampini bool gpusampling; 20*53022affSStefano Zampini h2opusComputeStream_t stream; 21*53022affSStefano Zampini 22*53022affSStefano Zampini private: 23*53022affSStefano Zampini void Init(); 24*53022affSStefano Zampini void VerifyBuffers(int); 25*53022affSStefano Zampini void PermuteBuffersIn(int,H2Opus_Real*,H2Opus_Real**,H2Opus_Real*,H2Opus_Real**); 26*53022affSStefano Zampini void PermuteBuffersOut(int,H2Opus_Real*); 27*53022affSStefano Zampini 28*53022affSStefano Zampini public: 29*53022affSStefano Zampini PetscMatrixSampler(); 30*53022affSStefano Zampini PetscMatrixSampler(Mat); 31*53022affSStefano Zampini ~PetscMatrixSampler(); 32*53022affSStefano Zampini void SetSamplingMat(Mat); 33*53022affSStefano Zampini void SetIndexMap(int,int*); 34*53022affSStefano Zampini void SetGPUSampling(bool); 35*53022affSStefano Zampini void SetStream(h2opusComputeStream_t); 36*53022affSStefano Zampini virtual void sample(H2Opus_Real*,H2Opus_Real*,int); 37*53022affSStefano Zampini Mat GetSamplingMat() { return A; } 38*53022affSStefano Zampini }; 39*53022affSStefano Zampini 40*53022affSStefano Zampini void PetscMatrixSampler::Init() 41*53022affSStefano Zampini { 42*53022affSStefano Zampini this->A = NULL; 43*53022affSStefano Zampini this->gpusampling = false; 44*53022affSStefano Zampini this->stream = NULL; 45*53022affSStefano Zampini } 46*53022affSStefano Zampini 47*53022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler() 48*53022affSStefano Zampini { 49*53022affSStefano Zampini Init(); 50*53022affSStefano Zampini } 51*53022affSStefano Zampini 52*53022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler(Mat A) 53*53022affSStefano Zampini { 54*53022affSStefano Zampini Init(); 55*53022affSStefano Zampini SetSamplingMat(A); 56*53022affSStefano Zampini } 57*53022affSStefano Zampini 58*53022affSStefano Zampini void PetscMatrixSampler::SetSamplingMat(Mat A) 59*53022affSStefano Zampini { 60*53022affSStefano Zampini PetscErrorCode ierr; 61*53022affSStefano Zampini PetscMPIInt size; 62*53022affSStefano Zampini 63*53022affSStefano Zampini ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRV(ierr); 64*53022affSStefano Zampini if (size > 1) CHKERRV(PETSC_ERR_SUP); 65*53022affSStefano Zampini ierr = PetscObjectReference((PetscObject)A);CHKERRV(ierr); 66*53022affSStefano Zampini ierr = MatDestroy(&this->A);CHKERRV(ierr); 67*53022affSStefano Zampini this->A = A; 68*53022affSStefano Zampini } 69*53022affSStefano Zampini 70*53022affSStefano Zampini void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) 71*53022affSStefano Zampini { 72*53022affSStefano Zampini this->stream = stream; 73*53022affSStefano Zampini } 74*53022affSStefano Zampini 75*53022affSStefano Zampini void PetscMatrixSampler::SetIndexMap(int n,int *indexmap) 76*53022affSStefano Zampini { 77*53022affSStefano Zampini copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 78*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 79*53022affSStefano Zampini copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 80*53022affSStefano Zampini #endif 81*53022affSStefano Zampini } 82*53022affSStefano Zampini 83*53022affSStefano Zampini void PetscMatrixSampler::VerifyBuffers(int nv) 84*53022affSStefano Zampini { 85*53022affSStefano Zampini if (this->hindexmap.size()) { 86*53022affSStefano Zampini size_t n = this->hindexmap.size(); 87*53022affSStefano Zampini if (!this->gpusampling) { 88*53022affSStefano Zampini if (hbuffer_in.size() < (size_t)n * nv) 89*53022affSStefano Zampini hbuffer_in.resize(n * nv); 90*53022affSStefano Zampini if (hbuffer_out.size() < (size_t)n * nv) 91*53022affSStefano Zampini hbuffer_out.resize(n * nv); 92*53022affSStefano Zampini } else { 93*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 94*53022affSStefano Zampini if (dbuffer_in.size() < (size_t)n * nv) 95*53022affSStefano Zampini dbuffer_in.resize(n * nv); 96*53022affSStefano Zampini if (dbuffer_out.size() < (size_t)n * nv) 97*53022affSStefano Zampini dbuffer_out.resize(n * nv); 98*53022affSStefano Zampini #endif 99*53022affSStefano Zampini } 100*53022affSStefano Zampini } 101*53022affSStefano Zampini } 102*53022affSStefano Zampini 103*53022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) 104*53022affSStefano Zampini { 105*53022affSStefano Zampini *w = v; 106*53022affSStefano Zampini *ow = ov; 107*53022affSStefano Zampini VerifyBuffers(nv); 108*53022affSStefano Zampini if (this->hindexmap.size()) { 109*53022affSStefano Zampini size_t n = this->hindexmap.size(); 110*53022affSStefano Zampini if (!this->gpusampling) { 111*53022affSStefano Zampini permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, 112*53022affSStefano Zampini this->stream); 113*53022affSStefano Zampini *w = this->hbuffer_in.data(); 114*53022affSStefano Zampini *ow = this->hbuffer_out.data(); 115*53022affSStefano Zampini } else { 116*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 117*53022affSStefano Zampini permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, 118*53022affSStefano Zampini this->stream); 119*53022affSStefano Zampini *w = this->dbuffer_in.data(); 120*53022affSStefano Zampini *ow = this->dbuffer_out.data(); 121*53022affSStefano Zampini #endif 122*53022affSStefano Zampini } 123*53022affSStefano Zampini } 124*53022affSStefano Zampini } 125*53022affSStefano Zampini 126*53022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) 127*53022affSStefano Zampini { 128*53022affSStefano Zampini VerifyBuffers(nv); 129*53022affSStefano Zampini if (this->hindexmap.size()) { 130*53022affSStefano Zampini size_t n = this->hindexmap.size(); 131*53022affSStefano Zampini if (!this->gpusampling) { 132*53022affSStefano Zampini permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, 133*53022affSStefano Zampini this->stream); 134*53022affSStefano Zampini } else { 135*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 136*53022affSStefano Zampini permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, 137*53022affSStefano Zampini this->stream); 138*53022affSStefano Zampini #endif 139*53022affSStefano Zampini } 140*53022affSStefano Zampini } 141*53022affSStefano Zampini } 142*53022affSStefano Zampini 143*53022affSStefano Zampini void PetscMatrixSampler::SetGPUSampling(bool gpusampling) 144*53022affSStefano Zampini { 145*53022affSStefano Zampini this->gpusampling = gpusampling; 146*53022affSStefano Zampini } 147*53022affSStefano Zampini 148*53022affSStefano Zampini PetscMatrixSampler::~PetscMatrixSampler() 149*53022affSStefano Zampini { 150*53022affSStefano Zampini PetscErrorCode ierr; 151*53022affSStefano Zampini 152*53022affSStefano Zampini ierr = MatDestroy(&A);CHKERRV(ierr); 153*53022affSStefano Zampini } 154*53022affSStefano Zampini 155*53022affSStefano Zampini void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) 156*53022affSStefano Zampini { 157*53022affSStefano Zampini PetscErrorCode ierr; 158*53022affSStefano Zampini MPI_Comm comm = PetscObjectComm((PetscObject)this->A); 159*53022affSStefano Zampini Mat X = NULL,Y = NULL; 160*53022affSStefano Zampini PetscInt M,N,m,n; 161*53022affSStefano Zampini H2Opus_Real *px,*py; 162*53022affSStefano Zampini 163*53022affSStefano Zampini if (!this->A) CHKERRV(PETSC_ERR_PLIB); 164*53022affSStefano Zampini ierr = MatGetSize(this->A,&M,&N);CHKERRV(ierr); 165*53022affSStefano Zampini ierr = MatGetLocalSize(this->A,&m,&n);CHKERRV(ierr); 166*53022affSStefano Zampini ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRV(ierr); 167*53022affSStefano Zampini PermuteBuffersIn(samples,x,&px,y,&py); 168*53022affSStefano Zampini if (!this->gpusampling) { 169*53022affSStefano Zampini ierr = MatCreateDense(comm,n,PETSC_DECIDE,N,samples,px,&X);CHKERRV(ierr); 170*53022affSStefano Zampini ierr = MatCreateDense(comm,m,PETSC_DECIDE,M,samples,py,&Y);CHKERRV(ierr); 171*53022affSStefano Zampini } else { 172*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 173*53022affSStefano Zampini ierr = MatCreateDenseCUDA(comm,n,PETSC_DECIDE,N,samples,px,&X);CHKERRV(ierr); 174*53022affSStefano Zampini ierr = MatCreateDenseCUDA(comm,m,PETSC_DECIDE,M,samples,py,&Y);CHKERRV(ierr); 175*53022affSStefano Zampini #endif 176*53022affSStefano Zampini } 177*53022affSStefano Zampini ierr = PetscLogObjectParent((PetscObject)this->A,(PetscObject)X);CHKERRV(ierr); 178*53022affSStefano Zampini ierr = PetscLogObjectParent((PetscObject)this->A,(PetscObject)Y);CHKERRV(ierr); 179*53022affSStefano Zampini ierr = MatMatMult(this->A,X,MAT_REUSE_MATRIX,PETSC_DEFAULT,&Y);CHKERRV(ierr); 180*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 181*53022affSStefano Zampini if (this->gpusampling) { 182*53022affSStefano Zampini const PetscScalar *dummy; 183*53022affSStefano Zampini ierr = MatDenseCUDAGetArrayRead(Y,&dummy);CHKERRV(ierr); 184*53022affSStefano Zampini ierr = MatDenseCUDARestoreArrayRead(Y,&dummy);CHKERRV(ierr); 185*53022affSStefano Zampini } 186*53022affSStefano Zampini #endif 187*53022affSStefano Zampini PermuteBuffersOut(samples,y); 188*53022affSStefano Zampini ierr = MatDestroy(&X);CHKERRV(ierr); 189*53022affSStefano Zampini ierr = MatDestroy(&Y);CHKERRV(ierr); 190*53022affSStefano Zampini } 191*53022affSStefano Zampini 192*53022affSStefano Zampini #endif 193