xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision 4877389942e948c907dd43b3a2d019cdca14ea44)
1 #include <petscvec_kokkos.hpp>
2 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3 #include <petscdevice.h>
4 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5 
6 /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
7 struct PC_VPBJacobi_Kokkos {
8   /* Cache the old sizes to check if we need realloc */
9   PetscInt n;       /* number of rows of the local matrix */
10   PetscInt nblocks; /* number of point blocks */
11   PetscInt nsize;   /* sum of sizes of the point blocks */
12 
13   /* Helper arrays that are pre-computed on host and then copied to device.
14     bs:     [nblocks+1], "csr" version of bsizes[]
15     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
16     matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block
17   */
18   PetscIntKokkosDualView    bs_dual, bs2_dual, matIdx_dual;
19   PetscScalarKokkosDualView diag_dual;
20 
21   PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes, MatScalar *diag_ptr_h) :
22     n(n), nblocks(nblocks), nsize(nsize), bs_dual("bs_dual", nblocks + 1), bs2_dual("bs2_dual", nblocks + 1), matIdx_dual("matIdx_dual", n)
23   {
24     PetscScalarKokkosViewHost diag_h(diag_ptr_h, nsize);
25 
26     auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(), diag_h);
27     diag_dual   = PetscScalarKokkosDualView(diag_d, diag_h);
28     PetscCallVoid(UpdateOffsetsOnDevice(bsizes, diag_ptr_h));
29   }
30 
31   PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes, MatScalar *diag_ptr_h)
32   {
33     PetscFunctionBegin;
34     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
35     PetscCall(ComputeOffsetsOnHost(bsizes));
36 
37     PetscCallCXX(bs_dual.modify_host());
38     PetscCallCXX(bs2_dual.modify_host());
39     PetscCallCXX(matIdx_dual.modify_host());
40     PetscCallCXX(diag_dual.modify_host());
41 
42     PetscCallCXX(bs_dual.sync_device());
43     PetscCallCXX(bs2_dual.sync_device());
44     PetscCallCXX(matIdx_dual.sync_device());
45     PetscCallCXX(diag_dual.sync_device());
46     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * nblocks + 2 + n) + sizeof(MatScalar) * nsize));
47     PetscFunctionReturn(0);
48   }
49 
50 private:
51   PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes)
52   {
53     PetscInt *bs_h     = bs_dual.view_host().data();
54     PetscInt *bs2_h    = bs2_dual.view_host().data();
55     PetscInt *matIdx_h = matIdx_dual.view_host().data();
56 
57     PetscFunctionBegin;
58     bs_h[0] = bs2_h[0] = 0;
59     for (PetscInt i = 0; i < nblocks; i++) {
60       bs_h[i + 1]  = bs_h[i] + bsizes[i];
61       bs2_h[i + 1] = bs2_h[i] + bsizes[i] * bsizes[i];
62       for (PetscInt j = 0; j < bsizes[i]; j++) matIdx_h[bs_h[i] + j] = i;
63     }
64     PetscFunctionReturn(0);
65   }
66 };
67 
68 static PetscErrorCode PCApply_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
69 {
70   PC_VPBJacobi              *jac   = (PC_VPBJacobi *)pc->data;
71   PC_VPBJacobi_Kokkos       *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
72   ConstPetscScalarKokkosView xv;
73   PetscScalarKokkosView      yv;
74   PetscScalarKokkosView      diag   = pckok->diag_dual.view_device();
75   PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
76   PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
77   PetscIntKokkosView         matIdx = pckok->matIdx_dual.view_device();
78 
79   PetscFunctionBegin;
80   PetscCall(PetscLogGpuTimeBegin());
81   VecErrorIfNotKokkos(x);
82   VecErrorIfNotKokkos(y);
83   PetscCall(VecGetKokkosView(x, &xv));
84   PetscCall(VecGetKokkosViewWrite(y, &yv));
85   PetscCallCXX(Kokkos::parallel_for(
86     "PCApply_VPBJacobi_Kokkos", pckok->n, KOKKOS_LAMBDA(PetscInt row) {
87       const PetscScalar *Ap, *xp;
88       PetscScalar       *yp;
89       PetscInt           i, j, k, m;
90 
91       k  = matIdx(row);       /* k-th block/matrix */
92       m  = bs(k + 1) - bs(k); /* block size of the k-th block */
93       i  = row - bs(k);       /* i-th row of the block */
94       Ap = &diag(bs2(k) + i); /* Ap points to the first entry of i-th row */
95       xp = &xv(bs(k));
96       yp = &yv(bs(k));
97 
98       yp[i] = 0.0;
99       for (j = 0; j < m; j++) {
100         yp[i] += Ap[0] * xp[j];
101         Ap += m;
102       }
103     }));
104   PetscCall(VecRestoreKokkosView(x, &xv));
105   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
106   PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
107   PetscCall(PetscLogGpuTimeEnd());
108   PetscFunctionReturn(0);
109 }
110 
111 static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
112 {
113   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
114 
115   PetscFunctionBegin;
116   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
117   PetscCall(PCDestroy_VPBJacobi(pc));
118   PetscFunctionReturn(0);
119 }
120 
121 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
122 {
123   PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
124   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
125   PetscInt             i, n, nblocks, nsize = 0;
126   const PetscInt      *bsizes;
127 
128   PetscFunctionBegin;
129   PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */
130   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
131   for (i = 0; i < nblocks; i++) nsize += bsizes[i] * bsizes[i];
132   PetscCall(MatGetLocalSize(pc->pmat, &n, NULL));
133 
134   /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */
135   if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
136     PetscCallCXX(delete pckok);
137     pckok = nullptr;
138   }
139 
140   if (!pckok) { /* allocate the struct along with the helper arrays from the scratch */
141     PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n, nblocks, nsize, bsizes, jac->diag));
142   } else { /* update the value only */
143     PetscCall(pckok->UpdateOffsetsOnDevice(bsizes, jac->diag));
144   }
145 
146   pc->ops->apply   = PCApply_VPBJacobi_Kokkos;
147   pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos;
148   PetscFunctionReturn(0);
149 }
150