xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision f1be35001843913b137132cb0fff0c609e27f59e)
1*f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp>
2*f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3*f1be3500SJunchao Zhang #include <petscdevice.h>
4*f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5*f1be3500SJunchao Zhang 
6*f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
7*f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos {
8*f1be3500SJunchao Zhang   /* Cache the old sizes to check if we need realloc */
9*f1be3500SJunchao Zhang   PetscInt  n; /* number of rows of the local matrix */
10*f1be3500SJunchao Zhang   PetscInt  nblocks; /* number of point blocks */
11*f1be3500SJunchao Zhang   PetscInt  nsize; /* sum of sizes of the point blocks */
12*f1be3500SJunchao Zhang 
13*f1be3500SJunchao Zhang   /* Helper arrays that are pre-computed on host and then copied to device.
14*f1be3500SJunchao Zhang     bs:     [nblocks+1], "csr" version of bsizes[]
15*f1be3500SJunchao Zhang     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
16*f1be3500SJunchao Zhang     matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block
17*f1be3500SJunchao Zhang   */
18*f1be3500SJunchao Zhang   PetscIntKokkosDualView     bs_dual, bs2_dual, matIdx_dual;
19*f1be3500SJunchao Zhang   PetscScalarKokkosDualView  diag_dual;
20*f1be3500SJunchao Zhang 
21*f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos(PetscInt n,PetscInt nblocks,PetscInt nsize,const PetscInt *bsizes,MatScalar *diag_ptr_h)
22*f1be3500SJunchao Zhang     : n(n),nblocks(nblocks),nsize(nsize),
23*f1be3500SJunchao Zhang       bs_dual("bs_dual",nblocks+1),bs2_dual("bs2_dual",nblocks+1),matIdx_dual("matIdx_dual",n)
24*f1be3500SJunchao Zhang   {
25*f1be3500SJunchao Zhang     PetscScalarKokkosViewHost    diag_h(diag_ptr_h,nsize);
26*f1be3500SJunchao Zhang 
27*f1be3500SJunchao Zhang     auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(),diag_h);
28*f1be3500SJunchao Zhang     diag_dual = PetscScalarKokkosDualView(diag_d,diag_h);
29*f1be3500SJunchao Zhang     PetscCallVoid(UpdateOffsetsOnDevice(bsizes,diag_ptr_h));
30*f1be3500SJunchao Zhang   }
31*f1be3500SJunchao Zhang 
32*f1be3500SJunchao Zhang   PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes,MatScalar *diag_ptr_h)
33*f1be3500SJunchao Zhang   {
34*f1be3500SJunchao Zhang     PetscFunctionBegin;
35*f1be3500SJunchao Zhang     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF,PETSC_ERR_PLIB,"Host pointer has changed since last call");
36*f1be3500SJunchao Zhang     PetscCall(ComputeOffsetsOnHost(bsizes));
37*f1be3500SJunchao Zhang 
38*f1be3500SJunchao Zhang     PetscCallCXX(bs_dual.modify_host());
39*f1be3500SJunchao Zhang     PetscCallCXX(bs2_dual.modify_host());
40*f1be3500SJunchao Zhang     PetscCallCXX(matIdx_dual.modify_host());
41*f1be3500SJunchao Zhang     PetscCallCXX(diag_dual.modify_host());
42*f1be3500SJunchao Zhang 
43*f1be3500SJunchao Zhang     PetscCallCXX(bs_dual.sync_device());
44*f1be3500SJunchao Zhang     PetscCallCXX(bs2_dual.sync_device());
45*f1be3500SJunchao Zhang     PetscCallCXX(matIdx_dual.sync_device());
46*f1be3500SJunchao Zhang     PetscCallCXX(diag_dual.sync_device());
47*f1be3500SJunchao Zhang     PetscFunctionReturn(0);
48*f1be3500SJunchao Zhang   }
49*f1be3500SJunchao Zhang 
50*f1be3500SJunchao Zhang private:
51*f1be3500SJunchao Zhang   PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes)
52*f1be3500SJunchao Zhang   {
53*f1be3500SJunchao Zhang     PetscInt *bs_h     = bs_dual.view_host().data();
54*f1be3500SJunchao Zhang     PetscInt *bs2_h    = bs2_dual.view_host().data();
55*f1be3500SJunchao Zhang     PetscInt *matIdx_h = matIdx_dual.view_host().data();
56*f1be3500SJunchao Zhang 
57*f1be3500SJunchao Zhang     PetscFunctionBegin;
58*f1be3500SJunchao Zhang     bs_h[0] = bs2_h[0] = 0;
59*f1be3500SJunchao Zhang     for (PetscInt i=0; i<nblocks; i++) {
60*f1be3500SJunchao Zhang       bs_h[i+1]  = bs_h[i]  + bsizes[i];
61*f1be3500SJunchao Zhang       bs2_h[i+1] = bs2_h[i] + bsizes[i]*bsizes[i];
62*f1be3500SJunchao Zhang       for (PetscInt j=0; j<bsizes[i]; j++) matIdx_h[bs_h[i]+j] = i;
63*f1be3500SJunchao Zhang     }
64*f1be3500SJunchao Zhang     PetscFunctionReturn(0);
65*f1be3500SJunchao Zhang   }
66*f1be3500SJunchao Zhang };
67*f1be3500SJunchao Zhang 
68*f1be3500SJunchao Zhang static PetscErrorCode PCApply_VPBJacobi_Kokkos(PC pc,Vec x,Vec y)
69*f1be3500SJunchao Zhang {
70*f1be3500SJunchao Zhang   PC_VPBJacobi               *jac = (PC_VPBJacobi*)pc->data;
71*f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos        *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr);
72*f1be3500SJunchao Zhang   ConstPetscScalarKokkosView xv;
73*f1be3500SJunchao Zhang   PetscScalarKokkosView      yv;
74*f1be3500SJunchao Zhang   PetscScalarKokkosView      diag = pckok->diag_dual.view_device();
75*f1be3500SJunchao Zhang   PetscIntKokkosView         bs = pckok->bs_dual.view_device();
76*f1be3500SJunchao Zhang   PetscIntKokkosView         bs2 = pckok->bs2_dual.view_device();
77*f1be3500SJunchao Zhang   PetscIntKokkosView         matIdx = pckok->matIdx_dual.view_device();
78*f1be3500SJunchao Zhang 
79*f1be3500SJunchao Zhang   PetscFunctionBegin;
80*f1be3500SJunchao Zhang   VecErrorIfNotKokkos(x);
81*f1be3500SJunchao Zhang   VecErrorIfNotKokkos(y);
82*f1be3500SJunchao Zhang   PetscCall(VecGetKokkosView(x,&xv));
83*f1be3500SJunchao Zhang   PetscCall(VecGetKokkosViewWrite(y,&yv));
84*f1be3500SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for("PCApply_VPBJacobi_Kokkos",pckok->n,KOKKOS_LAMBDA(PetscInt row) {
85*f1be3500SJunchao Zhang     const PetscScalar *Ap,*xp;
86*f1be3500SJunchao Zhang     PetscScalar       *yp;
87*f1be3500SJunchao Zhang     PetscInt          i,j,k,m;
88*f1be3500SJunchao Zhang 
89*f1be3500SJunchao Zhang     k  = matIdx(row);        /* k-th block/matrix */
90*f1be3500SJunchao Zhang     m  = bs(k+1) - bs(k);    /* block size of the k-th block */
91*f1be3500SJunchao Zhang     i  = row - bs(k);        /* i-th row of the block */
92*f1be3500SJunchao Zhang     Ap = &diag(bs2(k) + i);  /* Ap points to the first entry of i-th row */
93*f1be3500SJunchao Zhang     xp = &xv(bs(k));
94*f1be3500SJunchao Zhang     yp = &yv(bs(k));
95*f1be3500SJunchao Zhang 
96*f1be3500SJunchao Zhang     yp[i] = 0.0;
97*f1be3500SJunchao Zhang     for (j=0; j<m; j++) {yp[i] += Ap[0]*xp[j]; Ap += m;}
98*f1be3500SJunchao Zhang   }));
99*f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosView(x,&xv));
100*f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(y,&yv));
101*f1be3500SJunchao Zhang   PetscFunctionReturn(0);
102*f1be3500SJunchao Zhang }
103*f1be3500SJunchao Zhang 
104*f1be3500SJunchao Zhang static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
105*f1be3500SJunchao Zhang {
106*f1be3500SJunchao Zhang   PC_VPBJacobi       *jac = (PC_VPBJacobi*)pc->data;
107*f1be3500SJunchao Zhang 
108*f1be3500SJunchao Zhang   PetscFunctionBegin;
109*f1be3500SJunchao Zhang   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr));
110*f1be3500SJunchao Zhang   PetscCall(PCDestroy_VPBJacobi(pc));
111*f1be3500SJunchao Zhang   PetscFunctionReturn(0);
112*f1be3500SJunchao Zhang }
113*f1be3500SJunchao Zhang 
114*f1be3500SJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
115*f1be3500SJunchao Zhang {
116*f1be3500SJunchao Zhang   PC_VPBJacobi        *jac = (PC_VPBJacobi*)pc->data;
117*f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr);
118*f1be3500SJunchao Zhang   PetscInt            i,n,nblocks,nsize = 0;
119*f1be3500SJunchao Zhang   const PetscInt      *bsizes;
120*f1be3500SJunchao Zhang 
121*f1be3500SJunchao Zhang   PetscFunctionBegin;
122*f1be3500SJunchao Zhang   PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */
123*f1be3500SJunchao Zhang   PetscCall(MatGetVariableBlockSizes(pc->pmat,&nblocks,&bsizes));
124*f1be3500SJunchao Zhang   for (i=0; i<nblocks; i++) nsize += bsizes[i]*bsizes[i];
125*f1be3500SJunchao Zhang   PetscCall(MatGetLocalSize(pc->pmat,&n,NULL));
126*f1be3500SJunchao Zhang 
127*f1be3500SJunchao Zhang   /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */
128*f1be3500SJunchao Zhang   if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
129*f1be3500SJunchao Zhang     PetscCallCXX(delete pckok);
130*f1be3500SJunchao Zhang     pckok = nullptr;
131*f1be3500SJunchao Zhang   }
132*f1be3500SJunchao Zhang 
133*f1be3500SJunchao Zhang   if (!pckok) { /* allocate the struct along with the helper arrays from the scatch */
134*f1be3500SJunchao Zhang     PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n,nblocks,nsize,bsizes,jac->diag));
135*f1be3500SJunchao Zhang   } else { /* update the value only */
136*f1be3500SJunchao Zhang     PetscCall(pckok->UpdateOffsetsOnDevice(bsizes,jac->diag));
137*f1be3500SJunchao Zhang   }
138*f1be3500SJunchao Zhang 
139*f1be3500SJunchao Zhang   pc->ops->apply   = PCApply_VPBJacobi_Kokkos;
140*f1be3500SJunchao Zhang   pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos;
141*f1be3500SJunchao Zhang   PetscFunctionReturn(0);
142*f1be3500SJunchao Zhang }
143