1*a4963045SJacob Faibussowitsch #pragma once 2*a4963045SJacob Faibussowitsch 353022affSStefano Zampini #include <petscmat.h> 453022affSStefano Zampini #include <h2opus.h> 553022affSStefano Zampini 69371c9d4SSatish Balay class PetscMatrixSampler : public HMatrixSampler { 753022affSStefano Zampini protected: 853022affSStefano Zampini Mat A; 953022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector; 1053022affSStefano Zampini typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector; 1153022affSStefano Zampini HIntVector hindexmap; 1253022affSStefano Zampini HRealVector hbuffer_in, hbuffer_out; 1353022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1453022affSStefano Zampini H2OpusDeviceVector<int> dindexmap; 1553022affSStefano Zampini H2OpusDeviceVector<H2Opus_Real> dbuffer_in, dbuffer_out; 1653022affSStefano Zampini #endif 1753022affSStefano Zampini bool gpusampling; 1853022affSStefano Zampini h2opusComputeStream_t stream; 1953022affSStefano Zampini 2053022affSStefano Zampini private: 2153022affSStefano Zampini void Init(); 2253022affSStefano Zampini void VerifyBuffers(int); 2353022affSStefano Zampini void PermuteBuffersIn(int, H2Opus_Real *, H2Opus_Real **, H2Opus_Real *, H2Opus_Real **); 2453022affSStefano Zampini void PermuteBuffersOut(int, H2Opus_Real *); 2553022affSStefano Zampini 2653022affSStefano Zampini public: 2753022affSStefano Zampini PetscMatrixSampler(); 2853022affSStefano Zampini PetscMatrixSampler(Mat); 2953022affSStefano Zampini ~PetscMatrixSampler(); 3053022affSStefano Zampini void SetSamplingMat(Mat); 3153022affSStefano Zampini void SetIndexMap(int, int *); 3253022affSStefano Zampini void SetGPUSampling(bool); 3353022affSStefano Zampini void SetStream(h2opusComputeStream_t); 3453022affSStefano Zampini virtual void sample(H2Opus_Real *, H2Opus_Real *, int); 35d71ae5a4SJacob Faibussowitsch Mat GetSamplingMat() { return A; } 3653022affSStefano Zampini }; 3753022affSStefano Zampini 38d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::Init() 39d71ae5a4SJacob Faibussowitsch { 4053022affSStefano Zampini this->A = NULL; 4153022affSStefano Zampini this->gpusampling = false; 4253022affSStefano Zampini this->stream = NULL; 4353022affSStefano Zampini } 4453022affSStefano Zampini 45d71ae5a4SJacob Faibussowitsch PetscMatrixSampler::PetscMatrixSampler() 46d71ae5a4SJacob Faibussowitsch { 4753022affSStefano Zampini Init(); 4853022affSStefano Zampini } 4953022affSStefano Zampini 50d71ae5a4SJacob Faibussowitsch PetscMatrixSampler::PetscMatrixSampler(Mat A) 51d71ae5a4SJacob Faibussowitsch { 5253022affSStefano Zampini Init(); 5353022affSStefano Zampini SetSamplingMat(A); 5453022affSStefano Zampini } 5553022affSStefano Zampini 56d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetSamplingMat(Mat A) 57d71ae5a4SJacob Faibussowitsch { 58300d917bSStefano Zampini PetscMPIInt size = 1; 5953022affSStefano Zampini 603ba16761SJacob Faibussowitsch if (A) PetscCallVoid(static_cast<PetscErrorCode>(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size))); 619566063dSJacob Faibussowitsch if (size > 1) PetscCallVoid(PETSC_ERR_SUP); 629566063dSJacob Faibussowitsch PetscCallVoid(PetscObjectReference((PetscObject)A)); 639566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&this->A)); 6453022affSStefano Zampini this->A = A; 6553022affSStefano Zampini } 6653022affSStefano Zampini 67d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream) 68d71ae5a4SJacob Faibussowitsch { 6953022affSStefano Zampini this->stream = stream; 7053022affSStefano Zampini } 7153022affSStefano Zampini 72d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetIndexMap(int n, int *indexmap) 73d71ae5a4SJacob Faibussowitsch { 7453022affSStefano Zampini copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7553022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 7653022affSStefano Zampini copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU); 7753022affSStefano Zampini #endif 7853022affSStefano Zampini } 7953022affSStefano Zampini 80d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::VerifyBuffers(int nv) 81d71ae5a4SJacob Faibussowitsch { 8253022affSStefano Zampini if (this->hindexmap.size()) { 8353022affSStefano Zampini size_t n = this->hindexmap.size(); 8453022affSStefano Zampini if (!this->gpusampling) { 859371c9d4SSatish Balay if (hbuffer_in.size() < (size_t)n * nv) hbuffer_in.resize(n * nv); 869371c9d4SSatish Balay if (hbuffer_out.size() < (size_t)n * nv) hbuffer_out.resize(n * nv); 8753022affSStefano Zampini } else { 8853022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 899371c9d4SSatish Balay if (dbuffer_in.size() < (size_t)n * nv) dbuffer_in.resize(n * nv); 909371c9d4SSatish Balay if (dbuffer_out.size() < (size_t)n * nv) dbuffer_out.resize(n * nv); 9153022affSStefano Zampini #endif 9253022affSStefano Zampini } 9353022affSStefano Zampini } 9453022affSStefano Zampini } 9553022affSStefano Zampini 96d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow) 97d71ae5a4SJacob Faibussowitsch { 9853022affSStefano Zampini *w = v; 9953022affSStefano Zampini *ow = ov; 10053022affSStefano Zampini VerifyBuffers(nv); 10153022affSStefano Zampini if (this->hindexmap.size()) { 10253022affSStefano Zampini size_t n = this->hindexmap.size(); 10353022affSStefano Zampini if (!this->gpusampling) { 1049371c9d4SSatish Balay permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU, this->stream); 10553022affSStefano Zampini *w = this->hbuffer_in.data(); 10653022affSStefano Zampini *ow = this->hbuffer_out.data(); 10753022affSStefano Zampini } else { 10853022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1099371c9d4SSatish Balay permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU, this->stream); 11053022affSStefano Zampini *w = this->dbuffer_in.data(); 11153022affSStefano Zampini *ow = this->dbuffer_out.data(); 11253022affSStefano Zampini #endif 11353022affSStefano Zampini } 11453022affSStefano Zampini } 11553022affSStefano Zampini } 11653022affSStefano Zampini 117d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v) 118d71ae5a4SJacob Faibussowitsch { 11953022affSStefano Zampini VerifyBuffers(nv); 12053022affSStefano Zampini if (this->hindexmap.size()) { 12153022affSStefano Zampini size_t n = this->hindexmap.size(); 12253022affSStefano Zampini if (!this->gpusampling) { 1239371c9d4SSatish Balay permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU, this->stream); 12453022affSStefano Zampini } else { 12553022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU) 1269371c9d4SSatish Balay permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU, this->stream); 12753022affSStefano Zampini #endif 12853022affSStefano Zampini } 12953022affSStefano Zampini } 13053022affSStefano Zampini } 13153022affSStefano Zampini 132d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::SetGPUSampling(bool gpusampling) 133d71ae5a4SJacob Faibussowitsch { 13453022affSStefano Zampini this->gpusampling = gpusampling; 13553022affSStefano Zampini } 13653022affSStefano Zampini 137d71ae5a4SJacob Faibussowitsch PetscMatrixSampler::~PetscMatrixSampler() 138d71ae5a4SJacob Faibussowitsch { 1399566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&A)); 14053022affSStefano Zampini } 14153022affSStefano Zampini 142d71ae5a4SJacob Faibussowitsch void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples) 143d71ae5a4SJacob Faibussowitsch { 14453022affSStefano Zampini MPI_Comm comm = PetscObjectComm((PetscObject)this->A); 14553022affSStefano Zampini Mat X = NULL, Y = NULL; 14653022affSStefano Zampini PetscInt M, N, m, n; 14753022affSStefano Zampini H2Opus_Real *px, *py; 1482592a002SStefano Zampini VecType vtype; 14953022affSStefano Zampini 1509566063dSJacob Faibussowitsch if (!this->A) PetscCallVoid(PETSC_ERR_PLIB); 1519566063dSJacob Faibussowitsch PetscCallVoid(MatGetSize(this->A, &M, &N)); 1522592a002SStefano Zampini PetscCallVoid(MatGetVecType(this->A, &vtype)); 1539566063dSJacob Faibussowitsch PetscCallVoid(MatGetLocalSize(this->A, &m, &n)); 1549566063dSJacob Faibussowitsch PetscCallVoid(PetscObjectGetComm((PetscObject)A, &comm)); 15553022affSStefano Zampini PermuteBuffersIn(samples, x, &px, y, &py); 15653022affSStefano Zampini if (!this->gpusampling) { 1579566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDense(comm, n, PETSC_DECIDE, N, samples, px, &X)); 1589566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDense(comm, m, PETSC_DECIDE, M, samples, py, &Y)); 1592592a002SStefano Zampini PetscCallVoid(MatSetVecType(X, vtype)); 1602592a002SStefano Zampini PetscCallVoid(MatSetVecType(Y, vtype)); 1612592a002SStefano Zampini 16253022affSStefano Zampini } else { 16353022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 1649566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDenseCUDA(comm, n, PETSC_DECIDE, N, samples, px, &X)); 1659566063dSJacob Faibussowitsch PetscCallVoid(MatCreateDenseCUDA(comm, m, PETSC_DECIDE, M, samples, py, &Y)); 1662592a002SStefano Zampini PetscCallVoid(MatSetVecType(X, vtype)); 1672592a002SStefano Zampini PetscCallVoid(MatSetVecType(Y, vtype)); 16853022affSStefano Zampini #endif 16953022affSStefano Zampini } 1709566063dSJacob Faibussowitsch PetscCallVoid(MatMatMult(this->A, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y)); 17153022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) 17253022affSStefano Zampini if (this->gpusampling) { 17353022affSStefano Zampini const PetscScalar *dummy; 1749566063dSJacob Faibussowitsch PetscCallVoid(MatDenseCUDAGetArrayRead(Y, &dummy)); 1759566063dSJacob Faibussowitsch PetscCallVoid(MatDenseCUDARestoreArrayRead(Y, &dummy)); 17653022affSStefano Zampini } 17753022affSStefano Zampini #endif 17853022affSStefano Zampini PermuteBuffersOut(samples, y); 1799566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&X)); 1809566063dSJacob Faibussowitsch PetscCallVoid(MatDestroy(&Y)); 18153022affSStefano Zampini } 182