153022affSStefano Zampini #include <petscmat.h> 253022affSStefano Zampini #include <h2opus.h> 353022affSStefano Zampini 453022affSStefano Zampini #ifndef __MATH2OPUS_HPP 553022affSStefano Zampini #define __MATH2OPUS_HPP 653022affSStefano Zampini 753022affSStefano Zampini class PetscMatrixSampler : public HMatrixSampler 853022affSStefano Zampini { 953022affSStefano Zampini protected: 1053022affSStefano Zampini Mat A; 1153022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector; 1253022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector; 1353022affSStefano Zampini HIntVector hindexmap; 1453022affSStefano Zampini HRealVector hbuffer_in,hbuffer_out; 1553022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1653022affSStefano Zampini H2OpusDeviceVector<int> dindexmap; 1753022affSStefano Zampini H2OpusDeviceVector<H2Opus_Real> dbuffer_in,dbuffer_out; 1853022affSStefano Zampini #endif 1953022affSStefano Zampini bool gpusampling; 2053022affSStefano Zampini h2opusComputeStream_t stream; 2153022affSStefano Zampini 2253022affSStefano Zampini private: 2353022affSStefano Zampini void Init(); 2453022affSStefano Zampini void VerifyBuffers(int); 2553022affSStefano Zampini void PermuteBuffersIn(int,H2Opus_Real*,H2Opus_Real**,H2Opus_Real*,H2Opus_Real**); 2653022affSStefano Zampini void PermuteBuffersOut(int,H2Opus_Real*); 2753022affSStefano Zampini 2853022affSStefano Zampini public: 2953022affSStefano Zampini PetscMatrixSampler(); 3053022affSStefano Zampini PetscMatrixSampler(Mat); 3153022affSStefano Zampini ~PetscMatrixSampler(); 3253022affSStefano Zampini void SetSamplingMat(Mat); 3353022affSStefano Zampini void SetIndexMap(int,int*); 3453022affSStefano Zampini void SetGPUSampling(bool); 3553022affSStefano Zampini void SetStream(h2opusComputeStream_t); 3653022affSStefano Zampini virtual void sample(H2Opus_Real*,H2Opus_Real*,int); 3753022affSStefano Zampini Mat GetSamplingMat() { return A; } 3853022affSStefano Zampini }; 3953022affSStefano Zampini 4053022affSStefano Zampini void PetscMatrixSampler::Init() 4153022affSStefano Zampini { 4253022affSStefano Zampini this->A = NULL; 4353022affSStefano Zampini this->gpusampling = false; 4453022affSStefano Zampini this->stream = NULL; 4553022affSStefano Zampini } 4653022affSStefano Zampini 4753022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler() 4853022affSStefano Zampini { 4953022affSStefano Zampini Init(); 5053022affSStefano Zampini } 5153022affSStefano Zampini 5253022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler(Mat A) 5353022affSStefano Zampini { 5453022affSStefano Zampini Init(); 5553022affSStefano Zampini SetSamplingMat(A); 5653022affSStefano Zampini } 5753022affSStefano Zampini 5853022affSStefano Zampini void PetscMatrixSampler::SetSamplingMat(Mat A) 5953022affSStefano Zampini { 6053022affSStefano Zampini PetscErrorCode ierr; 61*300d917bSStefano Zampini PetscMPIInt size = 1; 6253022affSStefano Zampini 63*300d917bSStefano Zampini if (A) { ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRV(ierr); } 6453022affSStefano Zampini if (size > 1) CHKERRV(PETSC_ERR_SUP); 6553022affSStefano Zampini ierr = PetscObjectReference((PetscObject)A);CHKERRV(ierr); 6653022affSStefano Zampini ierr = MatDestroy(&this->A);CHKERRV(ierr); 6753022affSStefano Zampini this->A = A; 6853022affSStefano Zampini } 6953022affSStefano Zampini 7053022affSStefano Zampini void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) 7153022affSStefano Zampini { 7253022affSStefano Zampini this->stream = stream; 7353022affSStefano Zampini } 7453022affSStefano Zampini 7553022affSStefano Zampini void PetscMatrixSampler::SetIndexMap(int n,int *indexmap) 7653022affSStefano Zampini { 7753022affSStefano Zampini copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7853022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 7953022affSStefano Zampini copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 8053022affSStefano Zampini #endif 8153022affSStefano Zampini } 8253022affSStefano Zampini 8353022affSStefano Zampini void PetscMatrixSampler::VerifyBuffers(int nv) 8453022affSStefano Zampini { 8553022affSStefano Zampini if (this->hindexmap.size()) { 8653022affSStefano Zampini size_t n = this->hindexmap.size(); 8753022affSStefano Zampini if (!this->gpusampling) { 8853022affSStefano Zampini if (hbuffer_in.size() < (size_t)n * nv) 8953022affSStefano Zampini hbuffer_in.resize(n * nv); 9053022affSStefano Zampini if (hbuffer_out.size() < (size_t)n * nv) 9153022affSStefano Zampini hbuffer_out.resize(n * nv); 9253022affSStefano Zampini } else { 9353022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 9453022affSStefano Zampini if (dbuffer_in.size() < (size_t)n * nv) 9553022affSStefano Zampini dbuffer_in.resize(n * nv); 9653022affSStefano Zampini if (dbuffer_out.size() < (size_t)n * nv) 9753022affSStefano Zampini dbuffer_out.resize(n * nv); 9853022affSStefano Zampini #endif 9953022affSStefano Zampini } 10053022affSStefano Zampini } 10153022affSStefano Zampini } 10253022affSStefano Zampini 10353022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) 10453022affSStefano Zampini { 10553022affSStefano Zampini *w = v; 10653022affSStefano Zampini *ow = ov; 10753022affSStefano Zampini VerifyBuffers(nv); 10853022affSStefano Zampini if (this->hindexmap.size()) { 10953022affSStefano Zampini size_t n = this->hindexmap.size(); 11053022affSStefano Zampini if (!this->gpusampling) { 11153022affSStefano Zampini permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, 11253022affSStefano Zampini this->stream); 11353022affSStefano Zampini *w = this->hbuffer_in.data(); 11453022affSStefano Zampini *ow = this->hbuffer_out.data(); 11553022affSStefano Zampini } else { 11653022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 11753022affSStefano Zampini permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, 11853022affSStefano Zampini this->stream); 11953022affSStefano Zampini *w = this->dbuffer_in.data(); 12053022affSStefano Zampini *ow = this->dbuffer_out.data(); 12153022affSStefano Zampini #endif 12253022affSStefano Zampini } 12353022affSStefano Zampini } 12453022affSStefano Zampini } 12553022affSStefano Zampini 12653022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) 12753022affSStefano Zampini { 12853022affSStefano Zampini VerifyBuffers(nv); 12953022affSStefano Zampini if (this->hindexmap.size()) { 13053022affSStefano Zampini size_t n = this->hindexmap.size(); 13153022affSStefano Zampini if (!this->gpusampling) { 13253022affSStefano Zampini permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, 13353022affSStefano Zampini this->stream); 13453022affSStefano Zampini } else { 13553022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 13653022affSStefano Zampini permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, 13753022affSStefano Zampini this->stream); 13853022affSStefano Zampini #endif 13953022affSStefano Zampini } 14053022affSStefano Zampini } 14153022affSStefano Zampini } 14253022affSStefano Zampini 14353022affSStefano Zampini void PetscMatrixSampler::SetGPUSampling(bool gpusampling) 14453022affSStefano Zampini { 14553022affSStefano Zampini this->gpusampling = gpusampling; 14653022affSStefano Zampini } 14753022affSStefano Zampini 14853022affSStefano Zampini PetscMatrixSampler::~PetscMatrixSampler() 14953022affSStefano Zampini { 15053022affSStefano Zampini PetscErrorCode ierr; 15153022affSStefano Zampini 15253022affSStefano Zampini ierr = MatDestroy(&A);CHKERRV(ierr); 15353022affSStefano Zampini } 15453022affSStefano Zampini 15553022affSStefano Zampini void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) 15653022affSStefano Zampini { 15753022affSStefano Zampini PetscErrorCode ierr; 15853022affSStefano Zampini MPI_Comm comm = PetscObjectComm((PetscObject)this->A); 15953022affSStefano Zampini Mat X = NULL,Y = NULL; 16053022affSStefano Zampini PetscInt M,N,m,n; 16153022affSStefano Zampini H2Opus_Real *px,*py; 16253022affSStefano Zampini 16353022affSStefano Zampini if (!this->A) CHKERRV(PETSC_ERR_PLIB); 16453022affSStefano Zampini ierr = MatGetSize(this->A,&M,&N);CHKERRV(ierr); 16553022affSStefano Zampini ierr = MatGetLocalSize(this->A,&m,&n);CHKERRV(ierr); 16653022affSStefano Zampini ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRV(ierr); 16753022affSStefano Zampini PermuteBuffersIn(samples,x,&px,y,&py); 16853022affSStefano Zampini if (!this->gpusampling) { 16953022affSStefano Zampini ierr = MatCreateDense(comm,n,PETSC_DECIDE,N,samples,px,&X);CHKERRV(ierr); 17053022affSStefano Zampini ierr = MatCreateDense(comm,m,PETSC_DECIDE,M,samples,py,&Y);CHKERRV(ierr); 17153022affSStefano Zampini } else { 17253022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 17353022affSStefano Zampini ierr = MatCreateDenseCUDA(comm,n,PETSC_DECIDE,N,samples,px,&X);CHKERRV(ierr); 17453022affSStefano Zampini ierr = MatCreateDenseCUDA(comm,m,PETSC_DECIDE,M,samples,py,&Y);CHKERRV(ierr); 17553022affSStefano Zampini #endif 17653022affSStefano Zampini } 17753022affSStefano Zampini ierr = PetscLogObjectParent((PetscObject)this->A,(PetscObject)X);CHKERRV(ierr); 17853022affSStefano Zampini ierr = PetscLogObjectParent((PetscObject)this->A,(PetscObject)Y);CHKERRV(ierr); 17953022affSStefano Zampini ierr = MatMatMult(this->A,X,MAT_REUSE_MATRIX,PETSC_DEFAULT,&Y);CHKERRV(ierr); 18053022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 18153022affSStefano Zampini if (this->gpusampling) { 18253022affSStefano Zampini const PetscScalar *dummy; 18353022affSStefano Zampini ierr = MatDenseCUDAGetArrayRead(Y,&dummy);CHKERRV(ierr); 18453022affSStefano Zampini ierr = MatDenseCUDARestoreArrayRead(Y,&dummy);CHKERRV(ierr); 18553022affSStefano Zampini } 18653022affSStefano Zampini #endif 18753022affSStefano Zampini PermuteBuffersOut(samples,y); 18853022affSStefano Zampini ierr = MatDestroy(&X);CHKERRV(ierr); 18953022affSStefano Zampini ierr = MatDestroy(&Y);CHKERRV(ierr); 19053022affSStefano Zampini } 19153022affSStefano Zampini 19253022affSStefano Zampini #endif 193