1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
4 #include <petscdevice.h>
5 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos
7 #include <../src/mat/impls/aij/mpi/mpiaij.h> // for Mat_MPIAIJ
8 #include <KokkosBlas2_gemv.hpp>
9
10 /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
/* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
struct PC_VPBJacobi_Kokkos {
  /* Cache the old sizes to check if we need realloc */
  PetscInt n;       /* number of rows of the local matrix */
  PetscInt nblocks; /* number of point blocks */
  PetscInt nsize;   /* sum of sizes (elements) of the point blocks */

  /* Helper arrays that are pre-computed on host and then copied to device.
    bs:     [nblocks+1], "csr" version of bsizes[]
    bs2:    [nblocks+1], "csr" version of squares of bsizes[]
    blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block
  */
  PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual;
  PetscScalarKokkosView  diag;       // buffer to store diagonal blocks
  PetscScalarKokkosView  work;       // work buffer, with the same size as diag[]
  PetscLogDouble         setupFlops; // estimated flops of inverting all diagonal blocks; computed here, logged by the setup routine

  // clang-format off
  // n:               size of the matrix
  // nblocks:         number of blocks
  // nsize:           sum bsizes[i]^2 for i=0..nblocks
  // bsizes[nblocks]: sizes of blocks
  // All views are allocated without initialization (NoInit); BuildHelperArrays() fills the dual views.
  PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) :
    n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1),
    bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n),
    diag(NoInit("diag"), nsize), work(NoInit("work"), nsize)
  {
    PetscCallVoid(BuildHelperArrays(bsizes)); // PetscCallVoid since a constructor cannot return a PetscErrorCode
  }
  // clang-format on

private:
  // Build bs/bs2/blkMap on host from bsizes[], accumulate the setup-flops estimate,
  // then sync the three dual views to device.
  PetscErrorCode BuildHelperArrays(const PetscInt *bsizes)
  {
    PetscInt *bs_h     = bs_dual.view_host().data();
    PetscInt *bs2_h    = bs2_dual.view_host().data();
    PetscInt *blkMap_h = blkMap_dual.view_host().data();

    PetscFunctionBegin;
    setupFlops = 0.0;
    bs_h[0] = bs2_h[0] = 0;
    for (PetscInt i = 0; i < nblocks; i++) {
      PetscInt m = bsizes[i]; // size of the i-th point block
      // prefix sums: bs_h[i] is the first row of block i; bs2_h[i] is block i's offset into diag[]/work[]
      bs_h[i + 1]  = bs_h[i] + m;
      bs2_h[i + 1] = bs2_h[i] + m * m;
      for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i; // rows bs_h[i]..bs_h[i+1]-1 map back to block i
      // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse
      // => 4/3 m^3 FMAs, i.e. 8/3 m^3 flops counting 2 flops per FMA
      setupFlops += 8.0 * m * m * m / 3;
    }

    // mark the host side as modified, then copy host -> device
    PetscCallCXX(bs_dual.modify_host());
    PetscCallCXX(bs2_dual.modify_host());
    PetscCallCXX(blkMap_dual.modify_host());
    PetscCall(KokkosDualViewSyncDevice(bs_dual, PetscGetKokkosExecutionSpace()));
    PetscCall(KokkosDualViewSyncDevice(bs2_dual, PetscGetKokkosExecutionSpace()));
    PetscCall(KokkosDualViewSyncDevice(blkMap_dual, PetscGetKokkosExecutionSpace()));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
};
69
// Apply the block-diagonal preconditioner: y = D^{-1} x (or y = D^{-T} x when transpose).
// The inverted diagonal blocks live contiguously in pckok->diag, each block stored
// column-major (evidenced by the LayoutLeft wrapping in the disabled TeamGemv path below).
template <PetscBool transpose>
static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
{
  PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
  PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
  ConstPetscScalarKokkosView xv;
  PetscScalarKokkosView yv;
  PetscScalarKokkosView diag = pckok->diag;
  PetscIntKokkosView bs = pckok->bs_dual.view_device();      // block row offsets ("csr" of block sizes)
  PetscIntKokkosView bs2 = pckok->bs2_dual.view_device();    // block offsets into diag[]
  PetscIntKokkosView blkMap = pckok->blkMap_dual.view_device(); // row -> block id
  const char *label = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi";

  PetscFunctionBegin;
  PetscCall(PetscLogGpuTimeBegin());
  // NOTE(review): these two calls are not wrapped in PetscCall(); presumably the
  // macro/helper checks the Vec type and errors internally — confirm it does not
  // return an ignored PetscErrorCode.
  VecErrorIfNotKokkos(x);
  VecErrorIfNotKokkos(y);
  PetscCall(VecGetKokkosView(x, &xv));
  PetscCall(VecGetKokkosViewWrite(y, &yv));
#if 0 // TODO: Why the TeamGemv version is 2x worse than the naive one?
  PetscCallCXX(Kokkos::parallel_for(
    label, Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) {
      PetscInt bid = team.league_rank(); // block id
      PetscInt n = bs(bid + 1) - bs(bid); // size of this block
      const PetscScalar *bbuf = &diag(bs2(bid));
      const PetscScalar *xbuf = &xv(bs(bid));
      PetscScalar *ybuf = &yv(bs(bid));
      const auto &B = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order
      const auto &x1 = ConstPetscScalarKokkosView(xbuf, n);
      const auto &y1 = PetscScalarKokkosView(ybuf, n);
      if (transpose) {
        KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1
      } else {
        KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1
      }
    }));
#else
  // Naive kernel: one thread per matrix row; each thread computes one entry of y as a
  // dot product of its block row (or column, for the transpose) with the block's slice of x.
  PetscCallCXX(Kokkos::parallel_for(
    label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, pckok->n), KOKKOS_LAMBDA(PetscInt row) {
      const PetscScalar *Bp, *xp;
      PetscScalar *yp;
      PetscInt i, j, k, m;

      k = blkMap(row); /* k-th block/matrix */
      m = bs(k + 1) - bs(k); /* block size of the k-th block */
      i = row - bs(k); /* i-th row of the block */
      // With blocks stored column-major: row i of B starts at offset i with stride m;
      // column i of B (i.e. row i of B^T) starts at offset i*m with stride 1.
      Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of i-th row/column */
      xp = &xv(bs(k));
      yp = &yv(bs(k));

      yp[i] = 0.0;
      for (j = 0; j < m; j++) {
        yp[i] += Bp[0] * xp[j];
        Bp += transpose ? 1 : m; // advance along the row (stride m) or the column (stride 1)
      }
    }));
#endif
  PetscCall(VecRestoreKokkosView(x, &xv));
  PetscCall(VecRestoreKokkosViewWrite(y, &yv));
  PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
  PetscCall(PetscLogGpuTimeEnd());
  PetscFunctionReturn(PETSC_SUCCESS);
}
133
PCDestroy_VPBJacobi_Kokkos(PC pc)134 static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
135 {
136 PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
137
138 PetscFunctionBegin;
139 PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
140 PetscCall(PCDestroy_VPBJacobi(pc));
141 PetscFunctionReturn(PETSC_SUCCESS);
142 }
143
PCSetUp_VPBJacobi_Kokkos(PC pc,Mat diagVPB)144 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc, Mat diagVPB)
145 {
146 PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
147 PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
148 PetscInt i, nlocal, nblocks, nsize = 0;
149 const PetscInt *bsizes;
150 PetscBool ismpi;
151 Mat A;
152
153 PetscFunctionBegin;
154 PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
155 PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
156 PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");
157
158 if (!jac->diag) {
159 PetscInt max_bs = -1, min_bs = PETSC_INT_MAX;
160 for (i = 0; i < nblocks; i++) {
161 min_bs = PetscMin(min_bs, bsizes[i]);
162 max_bs = PetscMax(max_bs, bsizes[i]);
163 nsize += bsizes[i] * bsizes[i];
164 }
165 jac->nblocks = nblocks;
166 jac->min_bs = min_bs;
167 jac->max_bs = max_bs;
168 }
169
170 // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway
171 if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
172 PetscCallCXX(delete pckok);
173 pckok = nullptr;
174 }
175
176 PetscCall(PetscLogGpuTimeBegin());
177 if (!pckok) {
178 PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes));
179 jac->spptr = pckok;
180 }
181
182 // Extract diagonal blocks from the matrix and compute their inverse
183 const auto &bs = pckok->bs_dual.view_device();
184 const auto &bs2 = pckok->bs2_dual.view_device();
185 const auto &blkMap = pckok->blkMap_dual.view_device();
186 if (diagVPB) { // If caller provided a matrix made of the diagonal blocks, use it
187 PetscCall(PetscObjectBaseTypeCompare((PetscObject)diagVPB, MATMPIAIJ, &ismpi));
188 A = ismpi ? static_cast<Mat_MPIAIJ *>(diagVPB->data)->A : diagVPB;
189 } else {
190 PetscCall(PetscObjectBaseTypeCompare((PetscObject)pc->pmat, MATMPIAIJ, &ismpi));
191 A = ismpi ? static_cast<Mat_MPIAIJ *>(pc->pmat->data)->A : pc->pmat;
192 }
193 PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(A, bs, bs2, blkMap, pckok->work, pckok->diag));
194 pc->ops->apply = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>;
195 pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>;
196 pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos;
197 PetscCall(PetscLogGpuTimeEnd());
198 PetscCall(PetscLogGpuFlops(pckok->setupFlops));
199 PetscFunctionReturn(PETSC_SUCCESS);
200 }
201