xref: /petsc/src/mat/impls/h2opus/math2opussampler.hpp (revision 53022affac82b2fcec7b6432d0d3b2c8aa0487f8)
1*53022affSStefano Zampini #include <petscmat.h>
2*53022affSStefano Zampini #include <h2opus.h>
3*53022affSStefano Zampini 
4*53022affSStefano Zampini #ifndef __MATH2OPUS_HPP
5*53022affSStefano Zampini #define __MATH2OPUS_HPP
6*53022affSStefano Zampini 
7*53022affSStefano Zampini class PetscMatrixSampler : public HMatrixSampler
8*53022affSStefano Zampini {
9*53022affSStefano Zampini protected:
10*53022affSStefano Zampini   Mat  A;
11*53022affSStefano Zampini   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, H2Opus_Real>::type HRealVector;
12*53022affSStefano Zampini   typedef typename VectorContainer<H2OPUS_HWTYPE_CPU, int>::type HIntVector;
13*53022affSStefano Zampini   HIntVector hindexmap;
14*53022affSStefano Zampini   HRealVector hbuffer_in,hbuffer_out;
15*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
16*53022affSStefano Zampini   H2OpusDeviceVector<int> dindexmap;
17*53022affSStefano Zampini   H2OpusDeviceVector<H2Opus_Real> dbuffer_in,dbuffer_out;
18*53022affSStefano Zampini #endif
19*53022affSStefano Zampini   bool gpusampling;
20*53022affSStefano Zampini   h2opusComputeStream_t stream;
21*53022affSStefano Zampini 
22*53022affSStefano Zampini private:
23*53022affSStefano Zampini   void Init();
24*53022affSStefano Zampini   void VerifyBuffers(int);
25*53022affSStefano Zampini   void PermuteBuffersIn(int,H2Opus_Real*,H2Opus_Real**,H2Opus_Real*,H2Opus_Real**);
26*53022affSStefano Zampini   void PermuteBuffersOut(int,H2Opus_Real*);
27*53022affSStefano Zampini 
28*53022affSStefano Zampini public:
29*53022affSStefano Zampini   PetscMatrixSampler();
30*53022affSStefano Zampini   PetscMatrixSampler(Mat);
31*53022affSStefano Zampini   ~PetscMatrixSampler();
32*53022affSStefano Zampini   void SetSamplingMat(Mat);
33*53022affSStefano Zampini   void SetIndexMap(int,int*);
34*53022affSStefano Zampini   void SetGPUSampling(bool);
35*53022affSStefano Zampini   void SetStream(h2opusComputeStream_t);
36*53022affSStefano Zampini   virtual void sample(H2Opus_Real*,H2Opus_Real*,int);
37*53022affSStefano Zampini   Mat GetSamplingMat() { return A; }
38*53022affSStefano Zampini };
39*53022affSStefano Zampini 
40*53022affSStefano Zampini void PetscMatrixSampler::Init()
41*53022affSStefano Zampini {
42*53022affSStefano Zampini   this->A = NULL;
43*53022affSStefano Zampini   this->gpusampling = false;
44*53022affSStefano Zampini   this->stream = NULL;
45*53022affSStefano Zampini }
46*53022affSStefano Zampini 
47*53022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler()
48*53022affSStefano Zampini {
49*53022affSStefano Zampini   Init();
50*53022affSStefano Zampini }
51*53022affSStefano Zampini 
52*53022affSStefano Zampini PetscMatrixSampler::PetscMatrixSampler(Mat A)
53*53022affSStefano Zampini {
54*53022affSStefano Zampini   Init();
55*53022affSStefano Zampini   SetSamplingMat(A);
56*53022affSStefano Zampini }
57*53022affSStefano Zampini 
58*53022affSStefano Zampini void PetscMatrixSampler::SetSamplingMat(Mat A)
59*53022affSStefano Zampini {
60*53022affSStefano Zampini   PetscErrorCode ierr;
61*53022affSStefano Zampini   PetscMPIInt    size;
62*53022affSStefano Zampini 
63*53022affSStefano Zampini   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRV(ierr);
64*53022affSStefano Zampini   if (size > 1) CHKERRV(PETSC_ERR_SUP);
65*53022affSStefano Zampini   ierr = PetscObjectReference((PetscObject)A);CHKERRV(ierr);
66*53022affSStefano Zampini   ierr = MatDestroy(&this->A);CHKERRV(ierr);
67*53022affSStefano Zampini   this->A = A;
68*53022affSStefano Zampini }
69*53022affSStefano Zampini 
70*53022affSStefano Zampini void PetscMatrixSampler::SetStream(h2opusComputeStream_t stream)
71*53022affSStefano Zampini {
72*53022affSStefano Zampini   this->stream = stream;
73*53022affSStefano Zampini }
74*53022affSStefano Zampini 
75*53022affSStefano Zampini void PetscMatrixSampler::SetIndexMap(int n,int *indexmap)
76*53022affSStefano Zampini {
77*53022affSStefano Zampini   copyVector(this->hindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
78*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
79*53022affSStefano Zampini   copyVector(this->dindexmap, indexmap, n, H2OPUS_HWTYPE_CPU);
80*53022affSStefano Zampini #endif
81*53022affSStefano Zampini }
82*53022affSStefano Zampini 
83*53022affSStefano Zampini void PetscMatrixSampler::VerifyBuffers(int nv)
84*53022affSStefano Zampini {
85*53022affSStefano Zampini   if (this->hindexmap.size()) {
86*53022affSStefano Zampini     size_t n = this->hindexmap.size();
87*53022affSStefano Zampini     if (!this->gpusampling) {
88*53022affSStefano Zampini       if (hbuffer_in.size() < (size_t)n * nv)
89*53022affSStefano Zampini           hbuffer_in.resize(n * nv);
90*53022affSStefano Zampini       if (hbuffer_out.size() < (size_t)n * nv)
91*53022affSStefano Zampini           hbuffer_out.resize(n * nv);
92*53022affSStefano Zampini     } else {
93*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
94*53022affSStefano Zampini       if (dbuffer_in.size() < (size_t)n * nv)
95*53022affSStefano Zampini           dbuffer_in.resize(n * nv);
96*53022affSStefano Zampini       if (dbuffer_out.size() < (size_t)n * nv)
97*53022affSStefano Zampini           dbuffer_out.resize(n * nv);
98*53022affSStefano Zampini #endif
99*53022affSStefano Zampini     }
100*53022affSStefano Zampini   }
101*53022affSStefano Zampini }
102*53022affSStefano Zampini 
103*53022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersIn(int nv, H2Opus_Real *v, H2Opus_Real **w, H2Opus_Real *ov, H2Opus_Real **ow)
104*53022affSStefano Zampini {
105*53022affSStefano Zampini   *w = v;
106*53022affSStefano Zampini   *ow = ov;
107*53022affSStefano Zampini   VerifyBuffers(nv);
108*53022affSStefano Zampini   if (this->hindexmap.size()) {
109*53022affSStefano Zampini     size_t n = this->hindexmap.size();
110*53022affSStefano Zampini     if (!this->gpusampling) {
111*53022affSStefano Zampini       permute_vectors(v, this->hbuffer_in.data(), n, nv, this->hindexmap.data(), 1, H2OPUS_HWTYPE_CPU,
112*53022affSStefano Zampini                       this->stream);
113*53022affSStefano Zampini       *w = this->hbuffer_in.data();
114*53022affSStefano Zampini       *ow = this->hbuffer_out.data();
115*53022affSStefano Zampini     } else {
116*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
117*53022affSStefano Zampini       permute_vectors(v, this->dbuffer_in.data(), n, nv, this->dindexmap.data(), 1, H2OPUS_HWTYPE_GPU,
118*53022affSStefano Zampini                       this->stream);
119*53022affSStefano Zampini       *w = this->dbuffer_in.data();
120*53022affSStefano Zampini       *ow = this->dbuffer_out.data();
121*53022affSStefano Zampini #endif
122*53022affSStefano Zampini     }
123*53022affSStefano Zampini   }
124*53022affSStefano Zampini }
125*53022affSStefano Zampini 
126*53022affSStefano Zampini void PetscMatrixSampler::PermuteBuffersOut(int nv, H2Opus_Real *v)
127*53022affSStefano Zampini {
128*53022affSStefano Zampini   VerifyBuffers(nv);
129*53022affSStefano Zampini   if (this->hindexmap.size()) {
130*53022affSStefano Zampini     size_t n = this->hindexmap.size();
131*53022affSStefano Zampini     if (!this->gpusampling) {
132*53022affSStefano Zampini       permute_vectors(this->hbuffer_out.data(), v, n, nv, this->hindexmap.data(), 0, H2OPUS_HWTYPE_CPU,
133*53022affSStefano Zampini                       this->stream);
134*53022affSStefano Zampini     } else {
135*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
136*53022affSStefano Zampini       permute_vectors(this->dbuffer_out.data(), v, n, nv, this->dindexmap.data(), 0, H2OPUS_HWTYPE_GPU,
137*53022affSStefano Zampini                       this->stream);
138*53022affSStefano Zampini #endif
139*53022affSStefano Zampini     }
140*53022affSStefano Zampini   }
141*53022affSStefano Zampini }
142*53022affSStefano Zampini 
143*53022affSStefano Zampini void PetscMatrixSampler::SetGPUSampling(bool gpusampling)
144*53022affSStefano Zampini {
145*53022affSStefano Zampini   this->gpusampling = gpusampling;
146*53022affSStefano Zampini }
147*53022affSStefano Zampini 
148*53022affSStefano Zampini PetscMatrixSampler::~PetscMatrixSampler()
149*53022affSStefano Zampini {
150*53022affSStefano Zampini   PetscErrorCode ierr;
151*53022affSStefano Zampini 
152*53022affSStefano Zampini   ierr = MatDestroy(&A);CHKERRV(ierr);
153*53022affSStefano Zampini }
154*53022affSStefano Zampini 
155*53022affSStefano Zampini void PetscMatrixSampler::sample(H2Opus_Real *x, H2Opus_Real *y, int samples)
156*53022affSStefano Zampini {
157*53022affSStefano Zampini   PetscErrorCode ierr;
158*53022affSStefano Zampini   MPI_Comm       comm = PetscObjectComm((PetscObject)this->A);
159*53022affSStefano Zampini   Mat            X = NULL,Y = NULL;
160*53022affSStefano Zampini   PetscInt       M,N,m,n;
161*53022affSStefano Zampini   H2Opus_Real    *px,*py;
162*53022affSStefano Zampini 
163*53022affSStefano Zampini   if (!this->A) CHKERRV(PETSC_ERR_PLIB);
164*53022affSStefano Zampini   ierr = MatGetSize(this->A,&M,&N);CHKERRV(ierr);
165*53022affSStefano Zampini   ierr = MatGetLocalSize(this->A,&m,&n);CHKERRV(ierr);
166*53022affSStefano Zampini   ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRV(ierr);
167*53022affSStefano Zampini   PermuteBuffersIn(samples,x,&px,y,&py);
168*53022affSStefano Zampini   if (!this->gpusampling) {
169*53022affSStefano Zampini     ierr = MatCreateDense(comm,n,PETSC_DECIDE,N,samples,px,&X);CHKERRV(ierr);
170*53022affSStefano Zampini     ierr = MatCreateDense(comm,m,PETSC_DECIDE,M,samples,py,&Y);CHKERRV(ierr);
171*53022affSStefano Zampini   } else {
172*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA)
173*53022affSStefano Zampini     ierr = MatCreateDenseCUDA(comm,n,PETSC_DECIDE,N,samples,px,&X);CHKERRV(ierr);
174*53022affSStefano Zampini     ierr = MatCreateDenseCUDA(comm,m,PETSC_DECIDE,M,samples,py,&Y);CHKERRV(ierr);
175*53022affSStefano Zampini #endif
176*53022affSStefano Zampini   }
177*53022affSStefano Zampini   ierr = PetscLogObjectParent((PetscObject)this->A,(PetscObject)X);CHKERRV(ierr);
178*53022affSStefano Zampini   ierr = PetscLogObjectParent((PetscObject)this->A,(PetscObject)Y);CHKERRV(ierr);
179*53022affSStefano Zampini   ierr = MatMatMult(this->A,X,MAT_REUSE_MATRIX,PETSC_DEFAULT,&Y);CHKERRV(ierr);
180*53022affSStefano Zampini #if defined(PETSC_HAVE_CUDA)
181*53022affSStefano Zampini   if (this->gpusampling) {
182*53022affSStefano Zampini     const PetscScalar *dummy;
183*53022affSStefano Zampini     ierr = MatDenseCUDAGetArrayRead(Y,&dummy);CHKERRV(ierr);
184*53022affSStefano Zampini     ierr = MatDenseCUDARestoreArrayRead(Y,&dummy);CHKERRV(ierr);
185*53022affSStefano Zampini   }
186*53022affSStefano Zampini #endif
187*53022affSStefano Zampini   PermuteBuffersOut(samples,y);
188*53022affSStefano Zampini   ierr = MatDestroy(&X);CHKERRV(ierr);
189*53022affSStefano Zampini   ierr = MatDestroy(&Y);CHKERRV(ierr);
190*53022affSStefano Zampini }
191*53022affSStefano Zampini 
192*53022affSStefano Zampini #endif
193