153022affSStefano Zampini #include <petscmat.h> 253022affSStefano Zampini #include <h2opus.h> 353022affSStefano Zampini 453022affSStefano Zampini #ifndef __MATH2OPUS_HPP 553022affSStefano Zampini #define __MATH2OPUS_HPP 653022affSStefano Zampini 79371c9d4SSatish Balay class PetscMatrixSampler : public HMatrixSampler { 853022affSStefano Zampini protected: 953022affSStefano Zampini Mat A; 1053022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector; 1153022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector; 1253022affSStefano Zampini HIntVector hindexmap; 1353022affSStefano Zampini HRealVector hbuffer_in, hbuffer_out; 1453022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1553022affSStefano Zampini H2OpusDeviceVector<int> dindexmap; 1653022affSStefano Zampini H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out; 1753022affSStefano Zampini #endif 1853022affSStefano Zampini bool gpusampling; 1953022affSStefano Zampini h2opusComputeStream_t stream; 2053022affSStefano Zampini 2153022affSStefano Zampini private: 2253022affSStefano Zampini void Init(); 2353022affSStefano Zampini void VerifyBuffers(int); 2453022affSStefano Zampini void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **); 2553022affSStefano Zampini void PermuteBuffersOut(int, H2Opus_Real *); 2653022affSStefano Zampini 2753022affSStefano Zampini public: 2853022affSStefano Zampini PetscMatrixSampler(); 2953022affSStefano Zampini PetscMatrixSampler(Mat); 3053022affSStefano Zampini ~PetscMatrixSampler(); 3153022affSStefano Zampini void SetSamplingMat(Mat); 3253022affSStefano Zampini void SetIndexMap(int, int *); 3353022affSStefano Zampini void SetGPUSampling(bool); 3453022affSStefano Zampini void SetStream(h2opusComputeStream_t); 3553022affSStefano Zampini virtual void sample(H2Opus_Real *, H2Opus_Real *, int); 36d71ae5a4SJacob Faibussowitsch Mat GetSamplingMat() { return A; } 3753022affSStefano Zampini }; 3853022affSStefano Zampini 39d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::Init() 40d71ae5a4SJacob Faibussowitsch { 4153022affSStefano Zampini this->A = NULL; 4253022affSStefano Zampini this->gpusampling = false; 4353022affSStefano Zampini this->stream = NULL; 4453022affSStefano Zampini } 4553022affSStefano Zampini 46d71ae5a4SJacob Faibussowitsch PetscMatrixSampler::PetscMatrixSampler() 47d71ae5a4SJacob Faibussowitsch { 4853022affSStefano Zampini Init(); 4953022affSStefano Zampini } 5053022affSStefano Zampini 51d71ae5a4SJacob Faibussowitsch PetscMatrixSampler::PetscMatrixSampler(Mat A) 52d71ae5a4SJacob Faibussowitsch { 5353022affSStefano Zampini Init(); 5453022affSStefano Zampini SetSamplingMat(A); 5553022affSStefano Zampini } 5653022affSStefano Zampini 57d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetSamplingMat(Mat A) 58d71ae5a4SJacob Faibussowitsch { 59300d917bSStefano Zampini PetscMPIInt size = 1; 6053022affSStefano Zampini 613ba16761SJacob Faibussowitsch if (A) PetscCallVoid(static_cast<PetscErrorCode>(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size))); 629566063dSJacob Faibussowitsch if (size > 1) PetscCallVoid(PETSC_ERR_SUP); 639566063dSJacob Faibussowitsch PetscCallVoid(PetscObjectReference((PetscObject)A)); 649566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&this->A)); 6553022affSStefano Zampini this->A = A; 6653022affSStefano Zampini } 6753022affSStefano Zampini 68d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) 69d71ae5a4SJacob Faibussowitsch { 7053022affSStefano Zampini this->stream = stream; 7153022affSStefano Zampini } 7253022affSStefano Zampini 73d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetIndexMap(int n, int *indexmap) 74d71ae5a4SJacob Faibussowitsch { 7553022affSStefano Zampini copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7653022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 7753022affSStefano Zampini copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7853022affSStefano Zampini #endif 7953022affSStefano Zampini } 8053022affSStefano Zampini 81d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::VerifyBuffers(int nv) 82d71ae5a4SJacob Faibussowitsch { 8353022affSStefano Zampini if (this->hindexmap.size()) { 8453022affSStefano Zampini size_t n = this->hindexmap.size(); 8553022affSStefano Zampini if (!this->gpusampling) { 869371c9d4SSatish Balay if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv); 879371c9d4SSatish Balay if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv); 8853022affSStefano Zampini } else { 8953022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 909371c9d4SSatish Balay if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv); 919371c9d4SSatish Balay if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv); 9253022affSStefano Zampini #endif 9353022affSStefano Zampini } 9453022affSStefano Zampini } 9553022affSStefano Zampini } 9653022affSStefano Zampini 97d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) 98d71ae5a4SJacob Faibussowitsch { 9953022affSStefano Zampini *w = v; 10053022affSStefano Zampini *ow = ov; 10153022affSStefano Zampini VerifyBuffers(nv); 10253022affSStefano Zampini if (this->hindexmap.size()) { 10353022affSStefano Zampini size_t n = this->hindexmap.size(); 10453022affSStefano Zampini if (!this->gpusampling) { 1059371c9d4SSatish Balay permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream); 10653022affSStefano Zampini *w = this->hbuffer_in.data(); 10753022affSStefano Zampini *ow = this->hbuffer_out.data(); 10853022affSStefano Zampini } else { 10953022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1109371c9d4SSatish Balay permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream); 11153022affSStefano Zampini *w = this->dbuffer_in.data(); 11253022affSStefano Zampini *ow = this->dbuffer_out.data(); 11353022affSStefano Zampini #endif 11453022affSStefano Zampini } 11553022affSStefano Zampini } 11653022affSStefano Zampini } 11753022affSStefano Zampini 118d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) 119d71ae5a4SJacob Faibussowitsch { 12053022affSStefano Zampini VerifyBuffers(nv); 12153022affSStefano Zampini if (this->hindexmap.size()) { 12253022affSStefano Zampini size_t n = this->hindexmap.size(); 12353022affSStefano Zampini if (!this->gpusampling) { 1249371c9d4SSatish Balay permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream); 12553022affSStefano Zampini } else { 12653022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1279371c9d4SSatish Balay permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream); 12853022affSStefano Zampini #endif 12953022affSStefano Zampini } 13053022affSStefano Zampini } 13153022affSStefano Zampini } 13253022affSStefano Zampini 133d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetGPUSampling(bool gpusampling) 134d71ae5a4SJacob Faibussowitsch { 13553022affSStefano Zampini this->gpusampling = gpusampling; 13653022affSStefano Zampini } 13753022affSStefano Zampini 138d71ae5a4SJacob Faibussowitsch PetscMatrixSampler::~PetscMatrixSampler() 139d71ae5a4SJacob Faibussowitsch { 1409566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&A)); 14153022affSStefano Zampini } 14253022affSStefano Zampini 143d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) 144d71ae5a4SJacob Faibussowitsch { 14553022affSStefano Zampini MPI_Comm comm = PetscObjectComm((PetscObject)this->A); 14653022affSStefano Zampini Mat X = NULL, Y = NULL; 14753022affSStefano Zampini PetscInt M, N, m, n; 14853022affSStefano Zampini H2Opus_Real *px, *py; 149*2592a002SStefano Zampini VecType vtype; 15053022affSStefano Zampini 1519566063dSJacob Faibussowitsch if (!this->A) PetscCallVoid(PETSC_ERR_PLIB); 1529566063dSJacob Faibussowitsch PetscCallVoid(MatGetSize(this->A, &M, &N)); 153*2592a002SStefano Zampini PetscCallVoid(MatGetVecType(this->A, &vtype)); 1549566063dSJacob Faibussowitsch PetscCallVoid(MatGetLocalSize(this->A, &m, &n)); 1559566063dSJacob Faibussowitsch PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm)); 15653022affSStefano Zampini PermuteBuffersIn(samples, x, &px, y, &py); 15753022affSStefano Zampini if (!this->gpusampling) { 1589566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X)); 1599566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y)); 160*2592a002SStefano Zampini PetscCallVoid(MatSetVecType(X, vtype)); 161*2592a002SStefano Zampini PetscCallVoid(MatSetVecType(Y, vtype)); 162*2592a002SStefano Zampini 16353022affSStefano Zampini } else { 16453022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 1659566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X)); 1669566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y)); 167*2592a002SStefano Zampini PetscCallVoid(MatSetVecType(X, vtype)); 168*2592a002SStefano Zampini PetscCallVoid(MatSetVecType(Y, vtype)); 16953022affSStefano Zampini #endif 17053022affSStefano Zampini } 1719566063dSJacob Faibussowitsch PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y)); 17253022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 17353022affSStefano Zampini if (this->gpusampling) { 17453022affSStefano Zampini const PetscScalar *dummy; 1759566063dSJacob Faibussowitsch PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy)); 1769566063dSJacob Faibussowitsch PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy)); 17753022affSStefano Zampini } 17853022affSStefano Zampini #endif 17953022affSStefano Zampini PermuteBuffersOut(samples, y); 1809566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&X)); 1819566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&Y)); 18253022affSStefano Zampini } 18353022affSStefano Zampini 18453022affSStefano Zampini #endif 185