xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision 5cb80ecd61fc93edbe681011dc160a66bbda2b5d)
1 #include <petscvec_kokkos.hpp>
2 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3 #include <petscdevice.h>
4 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5 
6 /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
7 struct PC_VPBJacobi_Kokkos {
8   /* Cache the old sizes to check if we need realloc */
9   PetscInt  n; /* number of rows of the local matrix */
10   PetscInt  nblocks; /* number of point blocks */
11   PetscInt  nsize; /* sum of sizes of the point blocks */
12 
13   /* Helper arrays that are pre-computed on host and then copied to device.
14     bs:     [nblocks+1], "csr" version of bsizes[]
15     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
16     matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block
17   */
18   PetscIntKokkosDualView     bs_dual, bs2_dual, matIdx_dual;
19   PetscScalarKokkosDualView  diag_dual;
20 
21   PC_VPBJacobi_Kokkos(PetscInt n,PetscInt nblocks,PetscInt nsize,const PetscInt *bsizes,MatScalar *diag_ptr_h)
22     : n(n),nblocks(nblocks),nsize(nsize),
23       bs_dual("bs_dual",nblocks+1),bs2_dual("bs2_dual",nblocks+1),matIdx_dual("matIdx_dual",n)
24   {
25     PetscScalarKokkosViewHost    diag_h(diag_ptr_h,nsize);
26 
27     auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(),diag_h);
28     diag_dual = PetscScalarKokkosDualView(diag_d,diag_h);
29     PetscCallVoid(UpdateOffsetsOnDevice(bsizes,diag_ptr_h));
30   }
31 
32   PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes,MatScalar *diag_ptr_h)
33   {
34     PetscFunctionBegin;
35     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF,PETSC_ERR_PLIB,"Host pointer has changed since last call");
36     PetscCall(ComputeOffsetsOnHost(bsizes));
37 
38     PetscCallCXX(bs_dual.modify_host());
39     PetscCallCXX(bs2_dual.modify_host());
40     PetscCallCXX(matIdx_dual.modify_host());
41     PetscCallCXX(diag_dual.modify_host());
42 
43     PetscCallCXX(bs_dual.sync_device());
44     PetscCallCXX(bs2_dual.sync_device());
45     PetscCallCXX(matIdx_dual.sync_device());
46     PetscCallCXX(diag_dual.sync_device());
47     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt)*(2*nblocks+2+n) + sizeof(MatScalar)*nsize));
48     PetscFunctionReturn(0);
49   }
50 
51 private:
52   PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes)
53   {
54     PetscInt *bs_h     = bs_dual.view_host().data();
55     PetscInt *bs2_h    = bs2_dual.view_host().data();
56     PetscInt *matIdx_h = matIdx_dual.view_host().data();
57 
58     PetscFunctionBegin;
59     bs_h[0] = bs2_h[0] = 0;
60     for (PetscInt i=0; i<nblocks; i++) {
61       bs_h[i+1]  = bs_h[i]  + bsizes[i];
62       bs2_h[i+1] = bs2_h[i] + bsizes[i]*bsizes[i];
63       for (PetscInt j=0; j<bsizes[i]; j++) matIdx_h[bs_h[i]+j] = i;
64     }
65     PetscFunctionReturn(0);
66   }
67 };
68 
69 static PetscErrorCode PCApply_VPBJacobi_Kokkos(PC pc,Vec x,Vec y)
70 {
71   PC_VPBJacobi               *jac = (PC_VPBJacobi*)pc->data;
72   PC_VPBJacobi_Kokkos        *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr);
73   ConstPetscScalarKokkosView xv;
74   PetscScalarKokkosView      yv;
75   PetscScalarKokkosView      diag = pckok->diag_dual.view_device();
76   PetscIntKokkosView         bs = pckok->bs_dual.view_device();
77   PetscIntKokkosView         bs2 = pckok->bs2_dual.view_device();
78   PetscIntKokkosView         matIdx = pckok->matIdx_dual.view_device();
79 
80   PetscFunctionBegin;
81   PetscCall(PetscLogGpuTimeBegin());
82   VecErrorIfNotKokkos(x);
83   VecErrorIfNotKokkos(y);
84   PetscCall(VecGetKokkosView(x,&xv));
85   PetscCall(VecGetKokkosViewWrite(y,&yv));
86   PetscCallCXX(Kokkos::parallel_for("PCApply_VPBJacobi_Kokkos",pckok->n,KOKKOS_LAMBDA(PetscInt row) {
87     const PetscScalar *Ap,*xp;
88     PetscScalar       *yp;
89     PetscInt          i,j,k,m;
90 
91     k  = matIdx(row);        /* k-th block/matrix */
92     m  = bs(k+1) - bs(k);    /* block size of the k-th block */
93     i  = row - bs(k);        /* i-th row of the block */
94     Ap = &diag(bs2(k) + i);  /* Ap points to the first entry of i-th row */
95     xp = &xv(bs(k));
96     yp = &yv(bs(k));
97 
98     yp[i] = 0.0;
99     for (j=0; j<m; j++) {yp[i] += Ap[0]*xp[j]; Ap += m;}
100   }));
101   PetscCall(VecRestoreKokkosView(x,&xv));
102   PetscCall(VecRestoreKokkosViewWrite(y,&yv));
103   PetscCall(PetscLogGpuFlops(pckok->nsize*2)); /* FMA on entries in all blocks */
104   PetscCall(PetscLogGpuTimeEnd());
105   PetscFunctionReturn(0);
106 }
107 
108 static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
109 {
110   PC_VPBJacobi       *jac = (PC_VPBJacobi*)pc->data;
111 
112   PetscFunctionBegin;
113   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr));
114   PetscCall(PCDestroy_VPBJacobi(pc));
115   PetscFunctionReturn(0);
116 }
117 
118 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
119 {
120   PC_VPBJacobi        *jac = (PC_VPBJacobi*)pc->data;
121   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr);
122   PetscInt            i,n,nblocks,nsize = 0;
123   const PetscInt      *bsizes;
124 
125   PetscFunctionBegin;
126   PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */
127   PetscCall(MatGetVariableBlockSizes(pc->pmat,&nblocks,&bsizes));
128   for (i=0; i<nblocks; i++) nsize += bsizes[i]*bsizes[i];
129   PetscCall(MatGetLocalSize(pc->pmat,&n,NULL));
130 
131   /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */
132   if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
133     PetscCallCXX(delete pckok);
134     pckok = nullptr;
135   }
136 
137   if (!pckok) { /* allocate the struct along with the helper arrays from the scatch */
138     PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n,nblocks,nsize,bsizes,jac->diag));
139   } else { /* update the value only */
140     PetscCall(pckok->UpdateOffsetsOnDevice(bsizes,jac->diag));
141   }
142 
143   pc->ops->apply   = PCApply_VPBJacobi_Kokkos;
144   pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos;
145   PetscFunctionReturn(0);
146 }
147