xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision 4977266fffaaedcd41e5a67923646569c4e329aa)
1 #include <petscvec_kokkos.hpp>
2 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3 #include <petscdevice.h>
4 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5 
6 /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
7 struct PC_VPBJacobi_Kokkos {
8   /* Cache the old sizes to check if we need realloc */
9   PetscInt n;       /* number of rows of the local matrix */
10   PetscInt nblocks; /* number of point blocks */
11   PetscInt nsize;   /* sum of sizes of the point blocks */
12 
13   /* Helper arrays that are pre-computed on host and then copied to device.
14     bs:     [nblocks+1], "csr" version of bsizes[]
15     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
16     matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block
17   */
18   PetscIntKokkosDualView    bs_dual, bs2_dual, matIdx_dual;
19   PetscScalarKokkosDualView diag_dual;
20 
21   PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes, MatScalar *diag_ptr_h) :
22     n(n), nblocks(nblocks), nsize(nsize), bs_dual("bs_dual", nblocks + 1), bs2_dual("bs2_dual", nblocks + 1), matIdx_dual("matIdx_dual", n) {
23     PetscScalarKokkosViewHost diag_h(diag_ptr_h, nsize);
24 
25     auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(), diag_h);
26     diag_dual   = PetscScalarKokkosDualView(diag_d, diag_h);
27     PetscCallVoid(UpdateOffsetsOnDevice(bsizes, diag_ptr_h));
28   }
29 
30   PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes, MatScalar *diag_ptr_h) {
31     PetscFunctionBegin;
32     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
33     PetscCall(ComputeOffsetsOnHost(bsizes));
34 
35     PetscCallCXX(bs_dual.modify_host());
36     PetscCallCXX(bs2_dual.modify_host());
37     PetscCallCXX(matIdx_dual.modify_host());
38     PetscCallCXX(diag_dual.modify_host());
39 
40     PetscCallCXX(bs_dual.sync_device());
41     PetscCallCXX(bs2_dual.sync_device());
42     PetscCallCXX(matIdx_dual.sync_device());
43     PetscCallCXX(diag_dual.sync_device());
44     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * nblocks + 2 + n) + sizeof(MatScalar) * nsize));
45     PetscFunctionReturn(0);
46   }
47 
48 private:
49   PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes) {
50     PetscInt *bs_h     = bs_dual.view_host().data();
51     PetscInt *bs2_h    = bs2_dual.view_host().data();
52     PetscInt *matIdx_h = matIdx_dual.view_host().data();
53 
54     PetscFunctionBegin;
55     bs_h[0] = bs2_h[0] = 0;
56     for (PetscInt i = 0; i < nblocks; i++) {
57       bs_h[i + 1]  = bs_h[i] + bsizes[i];
58       bs2_h[i + 1] = bs2_h[i] + bsizes[i] * bsizes[i];
59       for (PetscInt j = 0; j < bsizes[i]; j++) matIdx_h[bs_h[i] + j] = i;
60     }
61     PetscFunctionReturn(0);
62   }
63 };
64 
65 static PetscErrorCode PCApply_VPBJacobi_Kokkos(PC pc, Vec x, Vec y) {
66   PC_VPBJacobi              *jac   = (PC_VPBJacobi *)pc->data;
67   PC_VPBJacobi_Kokkos       *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
68   ConstPetscScalarKokkosView xv;
69   PetscScalarKokkosView      yv;
70   PetscScalarKokkosView      diag   = pckok->diag_dual.view_device();
71   PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
72   PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
73   PetscIntKokkosView         matIdx = pckok->matIdx_dual.view_device();
74 
75   PetscFunctionBegin;
76   PetscCall(PetscLogGpuTimeBegin());
77   VecErrorIfNotKokkos(x);
78   VecErrorIfNotKokkos(y);
79   PetscCall(VecGetKokkosView(x, &xv));
80   PetscCall(VecGetKokkosViewWrite(y, &yv));
81   PetscCallCXX(Kokkos::parallel_for(
82     "PCApply_VPBJacobi_Kokkos", pckok->n, KOKKOS_LAMBDA(PetscInt row) {
83       const PetscScalar *Ap, *xp;
84       PetscScalar       *yp;
85       PetscInt           i, j, k, m;
86 
87       k  = matIdx(row);       /* k-th block/matrix */
88       m  = bs(k + 1) - bs(k); /* block size of the k-th block */
89       i  = row - bs(k);       /* i-th row of the block */
90       Ap = &diag(bs2(k) + i); /* Ap points to the first entry of i-th row */
91       xp = &xv(bs(k));
92       yp = &yv(bs(k));
93 
94       yp[i] = 0.0;
95       for (j = 0; j < m; j++) {
96         yp[i] += Ap[0] * xp[j];
97         Ap += m;
98       }
99     }));
100   PetscCall(VecRestoreKokkosView(x, &xv));
101   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
102   PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
103   PetscCall(PetscLogGpuTimeEnd());
104   PetscFunctionReturn(0);
105 }
106 
107 static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc) {
108   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
109 
110   PetscFunctionBegin;
111   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
112   PetscCall(PCDestroy_VPBJacobi(pc));
113   PetscFunctionReturn(0);
114 }
115 
116 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc) {
117   PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
118   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
119   PetscInt             i, n, nblocks, nsize = 0;
120   const PetscInt      *bsizes;
121 
122   PetscFunctionBegin;
123   PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */
124   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
125   for (i = 0; i < nblocks; i++) nsize += bsizes[i] * bsizes[i];
126   PetscCall(MatGetLocalSize(pc->pmat, &n, NULL));
127 
128   /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */
129   if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
130     PetscCallCXX(delete pckok);
131     pckok = nullptr;
132   }
133 
134   if (!pckok) { /* allocate the struct along with the helper arrays from the scatch */
135     PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n, nblocks, nsize, bsizes, jac->diag));
136   } else { /* update the value only */
137     PetscCall(pckok->UpdateOffsetsOnDevice(bsizes, jac->diag));
138   }
139 
140   pc->ops->apply   = PCApply_VPBJacobi_Kokkos;
141   pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos;
142   PetscFunctionReturn(0);
143 }
144