xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision db9fc945a3c8a0a4b4e9c7cf33c611e450ed8efd)
1 #include <petscvec_kokkos.hpp>
2 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3 #include <petscdevice.h>
4 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos
6 #include <KokkosBlas2_gemv.hpp>
7 
8 /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
9 struct PC_VPBJacobi_Kokkos {
10   /* Cache the old sizes to check if we need realloc */
11   PetscInt n;       /* number of rows of the local matrix */
12   PetscInt nblocks; /* number of point blocks */
13   PetscInt nsize;   /* sum of sizes (elements) of the point blocks */
14 
15   /* Helper arrays that are pre-computed on host and then copied to device.
16     bs:     [nblocks+1], "csr" version of bsizes[]
17     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
18     blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block
19   */
20   PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual;
21   PetscScalarKokkosView  diag; // buffer to store diagonal blocks
22   PetscScalarKokkosView  work; // work buffer, with the same size as diag[]
23   PetscLogDouble         setupFlops;
24 
25   // clang-format off
26   // n:               size of the matrix
27   // nblocks:         number of blocks
28   // nsize:           sum bsizes[i]^2 for i=0..nblocks
29   // bsizes[nblocks]: sizes of blocks
30   PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) :
31     n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1),
32     bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n),
33     diag(NoInit("diag"), nsize), work(NoInit("work"), nsize)
34   {
35     PetscCallVoid(BuildHelperArrays(bsizes));
36   }
37   // clang-format on
38 
39 private:
40   PetscErrorCode BuildHelperArrays(const PetscInt *bsizes)
41   {
42     PetscInt *bs_h     = bs_dual.view_host().data();
43     PetscInt *bs2_h    = bs2_dual.view_host().data();
44     PetscInt *blkMap_h = blkMap_dual.view_host().data();
45 
46     PetscFunctionBegin;
47     setupFlops = 0.0;
48     bs_h[0] = bs2_h[0] = 0;
49     for (PetscInt i = 0; i < nblocks; i++) {
50       PetscInt m   = bsizes[i];
51       bs_h[i + 1]  = bs_h[i] + m;
52       bs2_h[i + 1] = bs2_h[i] + m * m;
53       for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i;
54       // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse
55       setupFlops += 8.0 * m * m * m / 3;
56     }
57 
58     PetscCallCXX(bs_dual.modify_host());
59     PetscCallCXX(bs2_dual.modify_host());
60     PetscCallCXX(blkMap_dual.modify_host());
61     PetscCallCXX(bs_dual.sync_device());
62     PetscCallCXX(bs2_dual.sync_device());
63     PetscCallCXX(blkMap_dual.sync_device());
64     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * (nblocks + 1) + n)));
65     PetscFunctionReturn(PETSC_SUCCESS);
66   }
67 };
68 
69 template <PetscBool transpose>
70 static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
71 {
72   PC_VPBJacobi              *jac   = (PC_VPBJacobi *)pc->data;
73   PC_VPBJacobi_Kokkos       *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
74   ConstPetscScalarKokkosView xv;
75   PetscScalarKokkosView      yv;
76   PetscScalarKokkosView      diag   = pckok->diag;
77   PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
78   PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
79   PetscIntKokkosView         blkMap = pckok->blkMap_dual.view_device();
80   const char                *label  = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi";
81 
82   PetscFunctionBegin;
83   PetscCall(PetscLogGpuTimeBegin());
84   VecErrorIfNotKokkos(x);
85   VecErrorIfNotKokkos(y);
86   PetscCall(VecGetKokkosView(x, &xv));
87   PetscCall(VecGetKokkosViewWrite(y, &yv));
88 #if 0 // TODO: Why the TeamGemv version is 2x worse than the naive one?
89   PetscCallCXX(Kokkos::parallel_for(
90     label, Kokkos::TeamPolicy<>(jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) {
91       PetscInt           bid  = team.league_rank();    // block id
92       PetscInt           n    = bs(bid + 1) - bs(bid); // size of this block
93       const PetscScalar *bbuf = &diag(bs2(bid));
94       const PetscScalar *xbuf = &xv(bs(bid));
95       PetscScalar       *ybuf = &yv(bs(bid));
96       const auto        &B    = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order
97       const auto        &x1   = ConstPetscScalarKokkosView(xbuf, n);
98       const auto        &y1   = PetscScalarKokkosView(ybuf, n);
99       if (transpose) {
100         KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1
101       } else {
102         KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1
103       }
104     }));
105 #else
106   PetscCallCXX(Kokkos::parallel_for(
107     label, pckok->n, KOKKOS_LAMBDA(PetscInt row) {
108       const PetscScalar *Bp, *xp;
109       PetscScalar       *yp;
110       PetscInt           i, j, k, m;
111 
112       k  = blkMap(row);                             /* k-th block/matrix */
113       m  = bs(k + 1) - bs(k);                       /* block size of the k-th block */
114       i  = row - bs(k);                             /* i-th row of the block */
115       Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of i-th row/column */
116       xp = &xv(bs(k));
117       yp = &yv(bs(k));
118 
119       yp[i] = 0.0;
120       for (j = 0; j < m; j++) {
121         yp[i] += Bp[0] * xp[j];
122         Bp += transpose ? 1 : m;
123       }
124     }));
125 #endif
126   PetscCall(VecRestoreKokkosView(x, &xv));
127   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
128   PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
129   PetscCall(PetscLogGpuTimeEnd());
130   PetscFunctionReturn(PETSC_SUCCESS);
131 }
132 
133 static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
134 {
135   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
136 
137   PetscFunctionBegin;
138   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
139   PetscCall(PCDestroy_VPBJacobi(pc));
140   PetscFunctionReturn(PETSC_SUCCESS);
141 }
142 
143 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
144 {
145   PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
146   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
147   PetscInt             i, nlocal, nblocks, nsize = 0;
148   const PetscInt      *bsizes;
149 
150   PetscFunctionBegin;
151   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
152   PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
153   PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");
154 
155   if (!jac->diag) {
156     PetscInt max_bs = -1, min_bs = PETSC_MAX_INT;
157     for (i = 0; i < nblocks; i++) {
158       min_bs = PetscMin(min_bs, bsizes[i]);
159       max_bs = PetscMax(max_bs, bsizes[i]);
160       nsize += bsizes[i] * bsizes[i];
161     }
162     jac->nblocks = nblocks;
163     jac->min_bs  = min_bs;
164     jac->max_bs  = max_bs;
165   }
166 
167   // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway
168   if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
169     PetscCallCXX(delete pckok);
170     pckok = nullptr;
171   }
172 
173   PetscCall(PetscLogGpuTimeBegin());
174   if (!pckok) {
175     PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes));
176     jac->spptr = pckok;
177   }
178 
179   // Extract diagonal blocks from the matrix and compute their inverse
180   const auto &bs     = pckok->bs_dual.view_device();
181   const auto &bs2    = pckok->bs2_dual.view_device();
182   const auto &blkMap = pckok->blkMap_dual.view_device();
183   PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(pc->pmat, bs, bs2, blkMap, pckok->work, pckok->diag));
184   pc->ops->apply          = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>;
185   pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>;
186   pc->ops->destroy        = PCDestroy_VPBJacobi_Kokkos;
187   PetscCall(PetscLogGpuTimeEnd());
188   PetscCall(PetscLogGpuFlops(pckok->setupFlops));
189   PetscFunctionReturn(PETSC_SUCCESS);
190 }
191