1*f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp> 2*f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 3*f1be3500SJunchao Zhang #include <petscdevice.h> 4*f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h> 5*f1be3500SJunchao Zhang 6*f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */ 7*f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos { 8*f1be3500SJunchao Zhang /* Cache the old sizes to check if we need realloc */ 9*f1be3500SJunchao Zhang PetscInt n; /* number of rows of the local matrix */ 10*f1be3500SJunchao Zhang PetscInt nblocks; /* number of point blocks */ 11*f1be3500SJunchao Zhang PetscInt nsize; /* sum of sizes of the point blocks */ 12*f1be3500SJunchao Zhang 13*f1be3500SJunchao Zhang /* Helper arrays that are pre-computed on host and then copied to device. 14*f1be3500SJunchao Zhang bs: [nblocks+1], "csr" version of bsizes[] 15*f1be3500SJunchao Zhang bs2: [nblocks+1], "csr" version of squares of bsizes[] 16*f1be3500SJunchao Zhang matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block 17*f1be3500SJunchao Zhang */ 18*f1be3500SJunchao Zhang PetscIntKokkosDualView bs_dual, bs2_dual, matIdx_dual; 19*f1be3500SJunchao Zhang PetscScalarKokkosDualView diag_dual; 20*f1be3500SJunchao Zhang 21*f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos(PetscInt n,PetscInt nblocks,PetscInt nsize,const PetscInt *bsizes,MatScalar *diag_ptr_h) 22*f1be3500SJunchao Zhang : n(n),nblocks(nblocks),nsize(nsize), 23*f1be3500SJunchao Zhang bs_dual("bs_dual",nblocks+1),bs2_dual("bs2_dual",nblocks+1),matIdx_dual("matIdx_dual",n) 24*f1be3500SJunchao Zhang { 25*f1be3500SJunchao Zhang PetscScalarKokkosViewHost diag_h(diag_ptr_h,nsize); 26*f1be3500SJunchao Zhang 27*f1be3500SJunchao Zhang auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(),diag_h); 28*f1be3500SJunchao Zhang diag_dual = PetscScalarKokkosDualView(diag_d,diag_h); 29*f1be3500SJunchao Zhang PetscCallVoid(UpdateOffsetsOnDevice(bsizes,diag_ptr_h)); 30*f1be3500SJunchao Zhang } 31*f1be3500SJunchao Zhang 32*f1be3500SJunchao Zhang PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes,MatScalar *diag_ptr_h) 33*f1be3500SJunchao Zhang { 34*f1be3500SJunchao Zhang PetscFunctionBegin; 35*f1be3500SJunchao Zhang PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF,PETSC_ERR_PLIB,"Host pointer has changed since last call"); 36*f1be3500SJunchao Zhang PetscCall(ComputeOffsetsOnHost(bsizes)); 37*f1be3500SJunchao Zhang 38*f1be3500SJunchao Zhang PetscCallCXX(bs_dual.modify_host()); 39*f1be3500SJunchao Zhang PetscCallCXX(bs2_dual.modify_host()); 40*f1be3500SJunchao Zhang PetscCallCXX(matIdx_dual.modify_host()); 41*f1be3500SJunchao Zhang PetscCallCXX(diag_dual.modify_host()); 42*f1be3500SJunchao Zhang 43*f1be3500SJunchao Zhang PetscCallCXX(bs_dual.sync_device()); 44*f1be3500SJunchao Zhang PetscCallCXX(bs2_dual.sync_device()); 45*f1be3500SJunchao Zhang PetscCallCXX(matIdx_dual.sync_device()); 46*f1be3500SJunchao Zhang PetscCallCXX(diag_dual.sync_device()); 47*f1be3500SJunchao Zhang PetscFunctionReturn(0); 48*f1be3500SJunchao Zhang } 49*f1be3500SJunchao Zhang 50*f1be3500SJunchao Zhang private: 51*f1be3500SJunchao Zhang PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes) 52*f1be3500SJunchao Zhang { 53*f1be3500SJunchao Zhang PetscInt *bs_h = bs_dual.view_host().data(); 54*f1be3500SJunchao Zhang PetscInt *bs2_h = bs2_dual.view_host().data(); 55*f1be3500SJunchao Zhang PetscInt *matIdx_h = matIdx_dual.view_host().data(); 56*f1be3500SJunchao Zhang 57*f1be3500SJunchao Zhang PetscFunctionBegin; 58*f1be3500SJunchao Zhang bs_h[0] = bs2_h[0] = 0; 59*f1be3500SJunchao Zhang for (PetscInt i=0; i<nblocks; i++) { 60*f1be3500SJunchao Zhang bs_h[i+1] = bs_h[i] + bsizes[i]; 61*f1be3500SJunchao Zhang bs2_h[i+1] = bs2_h[i] + bsizes[i]*bsizes[i]; 62*f1be3500SJunchao Zhang for (PetscInt j=0; j<bsizes[i]; j++) matIdx_h[bs_h[i]+j] = i; 63*f1be3500SJunchao Zhang } 64*f1be3500SJunchao Zhang PetscFunctionReturn(0); 65*f1be3500SJunchao Zhang } 66*f1be3500SJunchao Zhang }; 67*f1be3500SJunchao Zhang 68*f1be3500SJunchao Zhang static PetscErrorCode PCApply_VPBJacobi_Kokkos(PC pc,Vec x,Vec y) 69*f1be3500SJunchao Zhang { 70*f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi*)pc->data; 71*f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr); 72*f1be3500SJunchao Zhang ConstPetscScalarKokkosView xv; 73*f1be3500SJunchao Zhang PetscScalarKokkosView yv; 74*f1be3500SJunchao Zhang PetscScalarKokkosView diag = pckok->diag_dual.view_device(); 75*f1be3500SJunchao Zhang PetscIntKokkosView bs = pckok->bs_dual.view_device(); 76*f1be3500SJunchao Zhang PetscIntKokkosView bs2 = pckok->bs2_dual.view_device(); 77*f1be3500SJunchao Zhang PetscIntKokkosView matIdx = pckok->matIdx_dual.view_device(); 78*f1be3500SJunchao Zhang 79*f1be3500SJunchao Zhang PetscFunctionBegin; 80*f1be3500SJunchao Zhang VecErrorIfNotKokkos(x); 81*f1be3500SJunchao Zhang VecErrorIfNotKokkos(y); 82*f1be3500SJunchao Zhang PetscCall(VecGetKokkosView(x,&xv)); 83*f1be3500SJunchao Zhang PetscCall(VecGetKokkosViewWrite(y,&yv)); 84*f1be3500SJunchao Zhang PetscCallCXX(Kokkos::parallel_for("PCApply_VPBJacobi_Kokkos",pckok->n,KOKKOS_LAMBDA(PetscInt row) { 85*f1be3500SJunchao Zhang const PetscScalar *Ap,*xp; 86*f1be3500SJunchao Zhang PetscScalar *yp; 87*f1be3500SJunchao Zhang PetscInt i,j,k,m; 88*f1be3500SJunchao Zhang 89*f1be3500SJunchao Zhang k = matIdx(row); /* k-th block/matrix */ 90*f1be3500SJunchao Zhang m = bs(k+1) - bs(k); /* block size of the k-th block */ 91*f1be3500SJunchao Zhang i = row - bs(k); /* i-th row of the block */ 92*f1be3500SJunchao Zhang Ap = &diag(bs2(k) + i); /* Ap points to the first entry of i-th row */ 93*f1be3500SJunchao Zhang xp = &xv(bs(k)); 94*f1be3500SJunchao Zhang yp = &yv(bs(k)); 95*f1be3500SJunchao Zhang 96*f1be3500SJunchao Zhang yp[i] = 0.0; 97*f1be3500SJunchao Zhang for (j=0; j<m; j++) {yp[i] += Ap[0]*xp[j]; Ap += m;} 98*f1be3500SJunchao Zhang })); 99*f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosView(x,&xv)); 100*f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(y,&yv)); 101*f1be3500SJunchao Zhang PetscFunctionReturn(0); 102*f1be3500SJunchao Zhang } 103*f1be3500SJunchao Zhang 104*f1be3500SJunchao Zhang static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc) 105*f1be3500SJunchao Zhang { 106*f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi*)pc->data; 107*f1be3500SJunchao Zhang 108*f1be3500SJunchao Zhang PetscFunctionBegin; 109*f1be3500SJunchao Zhang PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr)); 110*f1be3500SJunchao Zhang PetscCall(PCDestroy_VPBJacobi(pc)); 111*f1be3500SJunchao Zhang PetscFunctionReturn(0); 112*f1be3500SJunchao Zhang } 113*f1be3500SJunchao Zhang 114*f1be3500SJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc) 115*f1be3500SJunchao Zhang { 116*f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi*)pc->data; 117*f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr); 118*f1be3500SJunchao Zhang PetscInt i,n,nblocks,nsize = 0; 119*f1be3500SJunchao Zhang const PetscInt *bsizes; 120*f1be3500SJunchao Zhang 121*f1be3500SJunchao Zhang PetscFunctionBegin; 122*f1be3500SJunchao Zhang PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */ 123*f1be3500SJunchao Zhang PetscCall(MatGetVariableBlockSizes(pc->pmat,&nblocks,&bsizes)); 124*f1be3500SJunchao Zhang for (i=0; i<nblocks; i++) nsize += bsizes[i]*bsizes[i]; 125*f1be3500SJunchao Zhang PetscCall(MatGetLocalSize(pc->pmat,&n,NULL)); 126*f1be3500SJunchao Zhang 127*f1be3500SJunchao Zhang /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */ 128*f1be3500SJunchao Zhang if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) { 129*f1be3500SJunchao Zhang PetscCallCXX(delete pckok); 130*f1be3500SJunchao Zhang pckok = nullptr; 131*f1be3500SJunchao Zhang } 132*f1be3500SJunchao Zhang 133*f1be3500SJunchao Zhang if (!pckok) { /* allocate the struct along with the helper arrays from the scatch */ 134*f1be3500SJunchao Zhang PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n,nblocks,nsize,bsizes,jac->diag)); 135*f1be3500SJunchao Zhang } else { /* update the value only */ 136*f1be3500SJunchao Zhang PetscCall(pckok->UpdateOffsetsOnDevice(bsizes,jac->diag)); 137*f1be3500SJunchao Zhang } 138*f1be3500SJunchao Zhang 139*f1be3500SJunchao Zhang pc->ops->apply = PCApply_VPBJacobi_Kokkos; 140*f1be3500SJunchao Zhang pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos; 141*f1be3500SJunchao Zhang PetscFunctionReturn(0); 142*f1be3500SJunchao Zhang } 143