1 #include <petsc_kokkos.hpp> 2 #include <petscvec_kokkos.hpp> 3 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 4 #include <petscdevice.h> 5 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h> 6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos 7 #include <../src/mat/impls/aij/mpi/mpiaij.h> // for Mat_MPIAIJ 8 #include <KokkosBlas2_gemv.hpp> 9 10 /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */ 11 struct PC_VPBJacobi_Kokkos { 12 /* Cache the old sizes to check if we need realloc */ 13 PetscInt n; /* number of rows of the local matrix */ 14 PetscInt nblocks; /* number of point blocks */ 15 PetscInt nsize; /* sum of sizes (elements) of the point blocks */ 16 17 /* Helper arrays that are pre-computed on host and then copied to device. 18 bs: [nblocks+1], "csr" version of bsizes[] 19 bs2: [nblocks+1], "csr" version of squares of bsizes[] 20 blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block 21 */ 22 PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual; 23 PetscScalarKokkosView diag; // buffer to store diagonal blocks 24 PetscScalarKokkosView work; // work buffer, with the same size as diag[] 25 PetscLogDouble setupFlops; 26 27 // clang-format off 28 // n: size of the matrix 29 // nblocks: number of blocks 30 // nsize: sum bsizes[i]^2 for i=0..nblocks 31 // bsizes[nblocks]: sizes of blocks 32 PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) : 33 n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1), 34 bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n), 35 diag(NoInit("diag"), nsize), work(NoInit("work"), nsize) 36 { 37 PetscCallVoid(BuildHelperArrays(bsizes)); 38 } 39 // clang-format on 40 41 private: 42 PetscErrorCode BuildHelperArrays(const PetscInt *bsizes) 43 { 44 PetscInt *bs_h = bs_dual.view_host().data(); 45 PetscInt *bs2_h = bs2_dual.view_host().data(); 46 PetscInt *blkMap_h = blkMap_dual.view_host().data(); 47 48 PetscFunctionBegin; 49 setupFlops = 0.0; 50 bs_h[0] = bs2_h[0] = 0; 51 for (PetscInt i = 0; i < nblocks; i++) { 52 PetscInt m = bsizes[i]; 53 bs_h[i + 1] = bs_h[i] + m; 54 bs2_h[i + 1] = bs2_h[i] + m * m; 55 for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i; 56 // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse 57 setupFlops += 8.0 * m * m * m / 3; 58 } 59 60 PetscCallCXX(bs_dual.modify_host()); 61 PetscCallCXX(bs2_dual.modify_host()); 62 PetscCallCXX(blkMap_dual.modify_host()); 63 PetscCallCXX(bs_dual.sync_device()); 64 PetscCallCXX(bs2_dual.sync_device()); 65 PetscCallCXX(blkMap_dual.sync_device()); 66 PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * (nblocks + 1) + n))); 67 PetscFunctionReturn(PETSC_SUCCESS); 68 } 69 }; 70 71 template <PetscBool transpose> 72 static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y) 73 { 74 PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 75 PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 76 ConstPetscScalarKokkosView xv; 77 PetscScalarKokkosView yv; 78 PetscScalarKokkosView diag = pckok->diag; 79 PetscIntKokkosView bs = pckok->bs_dual.view_device(); 80 PetscIntKokkosView bs2 = pckok->bs2_dual.view_device(); 81 PetscIntKokkosView blkMap = pckok->blkMap_dual.view_device(); 82 const char *label = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi"; 83 84 PetscFunctionBegin; 85 PetscCall(PetscLogGpuTimeBegin()); 86 VecErrorIfNotKokkos(x); 87 VecErrorIfNotKokkos(y); 88 PetscCall(VecGetKokkosView(x, &xv)); 89 PetscCall(VecGetKokkosViewWrite(y, &yv)); 90 #if 0 // TODO: Why the TeamGemv version is 2x worse than the naive one? 91 PetscCallCXX(Kokkos::parallel_for( 92 label, Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) { 93 PetscInt bid = team.league_rank(); // block id 94 PetscInt n = bs(bid + 1) - bs(bid); // size of this block 95 const PetscScalar *bbuf = &diag(bs2(bid)); 96 const PetscScalar *xbuf = &xv(bs(bid)); 97 PetscScalar *ybuf = &yv(bs(bid)); 98 const auto &B = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order 99 const auto &x1 = ConstPetscScalarKokkosView(xbuf, n); 100 const auto &y1 = PetscScalarKokkosView(ybuf, n); 101 if (transpose) { 102 KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1 103 } else { 104 KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1 105 } 106 })); 107 #else 108 PetscCallCXX(Kokkos::parallel_for( 109 label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, pckok->n), KOKKOS_LAMBDA(PetscInt row) { 110 const PetscScalar *Bp, *xp; 111 PetscScalar *yp; 112 PetscInt i, j, k, m; 113 114 k = blkMap(row); /* k-th block/matrix */ 115 m = bs(k + 1) - bs(k); /* block size of the k-th block */ 116 i = row - bs(k); /* i-th row of the block */ 117 Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of i-th row/column */ 118 xp = &xv(bs(k)); 119 yp = &yv(bs(k)); 120 121 yp[i] = 0.0; 122 for (j = 0; j < m; j++) { 123 yp[i] += Bp[0] * xp[j]; 124 Bp += transpose ? 1 : m; 125 } 126 })); 127 #endif 128 PetscCall(VecRestoreKokkosView(x, &xv)); 129 PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 130 PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */ 131 PetscCall(PetscLogGpuTimeEnd()); 132 PetscFunctionReturn(PETSC_SUCCESS); 133 } 134 135 static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc) 136 { 137 PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 138 139 PetscFunctionBegin; 140 PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr)); 141 PetscCall(PCDestroy_VPBJacobi(pc)); 142 PetscFunctionReturn(PETSC_SUCCESS); 143 } 144 145 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc) 146 { 147 PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 148 PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 149 PetscInt i, nlocal, nblocks, nsize = 0; 150 const PetscInt *bsizes; 151 PetscBool ismpi; 152 Mat A; 153 154 PetscFunctionBegin; 155 PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes)); 156 PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL)); 157 PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI"); 158 159 if (!jac->diag) { 160 PetscInt max_bs = -1, min_bs = PETSC_MAX_INT; 161 for (i = 0; i < nblocks; i++) { 162 min_bs = PetscMin(min_bs, bsizes[i]); 163 max_bs = PetscMax(max_bs, bsizes[i]); 164 nsize += bsizes[i] * bsizes[i]; 165 } 166 jac->nblocks = nblocks; 167 jac->min_bs = min_bs; 168 jac->max_bs = max_bs; 169 } 170 171 // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway 172 if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) { 173 PetscCallCXX(delete pckok); 174 pckok = nullptr; 175 } 176 177 PetscCall(PetscLogGpuTimeBegin()); 178 if (!pckok) { 179 PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes)); 180 jac->spptr = pckok; 181 } 182 183 // Extract diagonal blocks from the matrix and compute their inverse 184 const auto &bs = pckok->bs_dual.view_device(); 185 const auto &bs2 = pckok->bs2_dual.view_device(); 186 const auto &blkMap = pckok->blkMap_dual.view_device(); 187 PetscCall(PetscObjectBaseTypeCompare((PetscObject)pc->pmat, MATMPIAIJ, &ismpi)); 188 A = ismpi ? static_cast<Mat_MPIAIJ *>(pc->pmat->data)->A : pc->pmat; 189 PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(A, bs, bs2, blkMap, pckok->work, pckok->diag)); 190 pc->ops->apply = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>; 191 pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>; 192 pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos; 193 PetscCall(PetscLogGpuTimeEnd()); 194 PetscCall(PetscLogGpuFlops(pckok->setupFlops)); 195 PetscFunctionReturn(PETSC_SUCCESS); 196 } 197