1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp> 2f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp> 3f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 4f1be3500SJunchao Zhang #include <petscdevice.h> 5f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h> 69d13fa56SJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos 75994c5adSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h> // for Mat_MPIAIJ 86dc07963SJunchao Zhang #include <KokkosBlas2_gemv.hpp> 9f1be3500SJunchao Zhang 10f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */ 11f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos { 12f1be3500SJunchao Zhang /* Cache the old sizes to check if we need realloc */ 13f1be3500SJunchao Zhang PetscInt n; /* number of rows of the local matrix */ 14f1be3500SJunchao Zhang PetscInt nblocks; /* number of point blocks */ 159d13fa56SJunchao Zhang PetscInt nsize; /* sum of sizes (elements) of the point blocks */ 16f1be3500SJunchao Zhang 17f1be3500SJunchao Zhang /* Helper arrays that are pre-computed on host and then copied to device. 18f1be3500SJunchao Zhang bs: [nblocks+1], "csr" version of bsizes[] 19f1be3500SJunchao Zhang bs2: [nblocks+1], "csr" version of squares of bsizes[] 209d13fa56SJunchao Zhang blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block 21f1be3500SJunchao Zhang */ 229d13fa56SJunchao Zhang PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual; 239d13fa56SJunchao Zhang PetscScalarKokkosView diag; // buffer to store diagonal blocks 249d13fa56SJunchao Zhang PetscScalarKokkosView work; // work buffer, with the same size as diag[] 259d13fa56SJunchao Zhang PetscLogDouble setupFlops; 26f1be3500SJunchao Zhang 279d13fa56SJunchao Zhang // clang-format off 289d13fa56SJunchao Zhang // n: size of the matrix 299d13fa56SJunchao Zhang // nblocks: number of blocks 309d13fa56SJunchao Zhang // nsize: sum bsizes[i]^2 for i=0..nblocks 319d13fa56SJunchao Zhang // bsizes[nblocks]: sizes of blocks 329d13fa56SJunchao Zhang PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) : 339d13fa56SJunchao Zhang n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1), 349d13fa56SJunchao Zhang bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n), 359d13fa56SJunchao Zhang diag(NoInit("diag"), nsize), work(NoInit("work"), nsize) 36d71ae5a4SJacob Faibussowitsch { 379d13fa56SJunchao Zhang PetscCallVoid(BuildHelperArrays(bsizes)); 38f1be3500SJunchao Zhang } 399d13fa56SJunchao Zhang // clang-format on 40f1be3500SJunchao Zhang 41f1be3500SJunchao Zhang private: 429d13fa56SJunchao Zhang PetscErrorCode BuildHelperArrays(const PetscInt *bsizes) 43d71ae5a4SJacob Faibussowitsch { 44f1be3500SJunchao Zhang PetscInt *bs_h = bs_dual.view_host().data(); 45f1be3500SJunchao Zhang PetscInt *bs2_h = bs2_dual.view_host().data(); 469d13fa56SJunchao Zhang PetscInt *blkMap_h = blkMap_dual.view_host().data(); 47f1be3500SJunchao Zhang 48f1be3500SJunchao Zhang PetscFunctionBegin; 499d13fa56SJunchao Zhang setupFlops = 0.0; 50f1be3500SJunchao Zhang bs_h[0] = bs2_h[0] = 0; 51f1be3500SJunchao Zhang for (PetscInt i = 0; i < nblocks; i++) { 529d13fa56SJunchao Zhang PetscInt m = bsizes[i]; 539d13fa56SJunchao Zhang bs_h[i + 1] = bs_h[i] + m; 549d13fa56SJunchao Zhang bs2_h[i + 1] = bs2_h[i] + m * m; 559d13fa56SJunchao Zhang for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i; 569d13fa56SJunchao Zhang // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse 579d13fa56SJunchao Zhang setupFlops += 8.0 * m * m * m / 3; 58f1be3500SJunchao Zhang } 599d13fa56SJunchao Zhang 609d13fa56SJunchao Zhang PetscCallCXX(bs_dual.modify_host()); 619d13fa56SJunchao Zhang PetscCallCXX(bs2_dual.modify_host()); 629d13fa56SJunchao Zhang PetscCallCXX(blkMap_dual.modify_host()); 63*f3d3cd90SZach Atkins PetscCall(KokkosDualViewSyncDevice(bs_dual, PetscGetKokkosExecutionSpace())); 64*f3d3cd90SZach Atkins PetscCall(KokkosDualViewSyncDevice(bs2_dual, PetscGetKokkosExecutionSpace())); 65*f3d3cd90SZach Atkins PetscCall(KokkosDualViewSyncDevice(blkMap_dual, PetscGetKokkosExecutionSpace())); 663ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 67f1be3500SJunchao Zhang } 68f1be3500SJunchao Zhang }; 69f1be3500SJunchao Zhang 7069eda9daSJed Brown template <PetscBool transpose> 7169eda9daSJed Brown static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y) 72d71ae5a4SJacob Faibussowitsch { 73f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 74f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 75f1be3500SJunchao Zhang ConstPetscScalarKokkosView xv; 76f1be3500SJunchao Zhang PetscScalarKokkosView yv; 779d13fa56SJunchao Zhang PetscScalarKokkosView diag = pckok->diag; 78f1be3500SJunchao Zhang PetscIntKokkosView bs = pckok->bs_dual.view_device(); 79f1be3500SJunchao Zhang PetscIntKokkosView bs2 = pckok->bs2_dual.view_device(); 809d13fa56SJunchao Zhang PetscIntKokkosView blkMap = pckok->blkMap_dual.view_device(); 816dc07963SJunchao Zhang const char *label = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi"; 82f1be3500SJunchao Zhang 83f1be3500SJunchao Zhang PetscFunctionBegin; 849a56b474SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 85f1be3500SJunchao Zhang VecErrorIfNotKokkos(x); 86f1be3500SJunchao Zhang VecErrorIfNotKokkos(y); 87f1be3500SJunchao Zhang PetscCall(VecGetKokkosView(x, &xv)); 88f1be3500SJunchao Zhang PetscCall(VecGetKokkosViewWrite(y, &yv)); 896dc07963SJunchao Zhang #if 0 // TODO: Why the TeamGemv version is 2x worse than the naive one? 906dc07963SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 91d326c3f1SJunchao Zhang label, Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) { 926dc07963SJunchao Zhang PetscInt bid = team.league_rank(); // block id 936dc07963SJunchao Zhang PetscInt n = bs(bid + 1) - bs(bid); // size of this block 946dc07963SJunchao Zhang const PetscScalar *bbuf = &diag(bs2(bid)); 956dc07963SJunchao Zhang const PetscScalar *xbuf = &xv(bs(bid)); 966dc07963SJunchao Zhang PetscScalar *ybuf = &yv(bs(bid)); 976dc07963SJunchao Zhang const auto &B = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order 986dc07963SJunchao Zhang const auto &x1 = ConstPetscScalarKokkosView(xbuf, n); 996dc07963SJunchao Zhang const auto &y1 = PetscScalarKokkosView(ybuf, n); 1006dc07963SJunchao Zhang if (transpose) { 1016dc07963SJunchao Zhang KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1 1026dc07963SJunchao Zhang } else { 1036dc07963SJunchao Zhang KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1 1046dc07963SJunchao Zhang } 1056dc07963SJunchao Zhang })); 1066dc07963SJunchao Zhang #else 1079371c9d4SSatish Balay PetscCallCXX(Kokkos::parallel_for( 108d326c3f1SJunchao Zhang label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, pckok->n), KOKKOS_LAMBDA(PetscInt row) { 1096dc07963SJunchao Zhang const PetscScalar *Bp, *xp; 110f1be3500SJunchao Zhang PetscScalar *yp; 111f1be3500SJunchao Zhang PetscInt i, j, k, m; 112f1be3500SJunchao Zhang 1139d13fa56SJunchao Zhang k = blkMap(row); /* k-th block/matrix */ 114f1be3500SJunchao Zhang m = bs(k + 1) - bs(k); /* block size of the k-th block */ 115f1be3500SJunchao Zhang i = row - bs(k); /* i-th row of the block */ 1166dc07963SJunchao Zhang Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of i-th row/column */ 117f1be3500SJunchao Zhang xp = &xv(bs(k)); 118f1be3500SJunchao Zhang yp = &yv(bs(k)); 119f1be3500SJunchao Zhang 120f1be3500SJunchao Zhang yp[i] = 0.0; 1219371c9d4SSatish Balay for (j = 0; j < m; j++) { 1226dc07963SJunchao Zhang yp[i] += Bp[0] * xp[j]; 1236dc07963SJunchao Zhang Bp += transpose ? 1 : m; 1249371c9d4SSatish Balay } 125f1be3500SJunchao Zhang })); 1266dc07963SJunchao Zhang #endif 127f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosView(x, &xv)); 128f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 1299a56b474SJunchao Zhang PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */ 1309a56b474SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1313ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 132f1be3500SJunchao Zhang } 133f1be3500SJunchao Zhang 134d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc) 135d71ae5a4SJacob Faibussowitsch { 136f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 137f1be3500SJunchao Zhang 138f1be3500SJunchao Zhang PetscFunctionBegin; 139f1be3500SJunchao Zhang PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr)); 140f1be3500SJunchao Zhang PetscCall(PCDestroy_VPBJacobi(pc)); 1413ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 142f1be3500SJunchao Zhang } 143f1be3500SJunchao Zhang 14442ce410bSJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc, Mat diagVPB) 145d71ae5a4SJacob Faibussowitsch { 146f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 147f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 1489d13fa56SJunchao Zhang PetscInt i, nlocal, nblocks, nsize = 0; 149f1be3500SJunchao Zhang const PetscInt *bsizes; 1505994c5adSJunchao Zhang PetscBool ismpi; 1515994c5adSJunchao Zhang Mat A; 152f1be3500SJunchao Zhang 153f1be3500SJunchao Zhang PetscFunctionBegin; 154f1be3500SJunchao Zhang PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes)); 1559d13fa56SJunchao Zhang PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL)); 1569d13fa56SJunchao Zhang PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI"); 157f1be3500SJunchao Zhang 1589d13fa56SJunchao Zhang if (!jac->diag) { 1591690c2aeSBarry Smith PetscInt max_bs = -1, min_bs = PETSC_INT_MAX; 1609d13fa56SJunchao Zhang for (i = 0; i < nblocks; i++) { 1619d13fa56SJunchao Zhang min_bs = PetscMin(min_bs, bsizes[i]); 1629d13fa56SJunchao Zhang max_bs = PetscMax(max_bs, bsizes[i]); 1639d13fa56SJunchao Zhang nsize += bsizes[i] * bsizes[i]; 1649d13fa56SJunchao Zhang } 1659d13fa56SJunchao Zhang jac->nblocks = nblocks; 1669d13fa56SJunchao Zhang jac->min_bs = min_bs; 1679d13fa56SJunchao Zhang jac->max_bs = max_bs; 1689d13fa56SJunchao Zhang } 1699d13fa56SJunchao Zhang 1709d13fa56SJunchao Zhang // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway 1719d13fa56SJunchao Zhang if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) { 172f1be3500SJunchao Zhang PetscCallCXX(delete pckok); 173f1be3500SJunchao Zhang pckok = nullptr; 174f1be3500SJunchao Zhang } 175f1be3500SJunchao Zhang 1769d13fa56SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1779d13fa56SJunchao Zhang if (!pckok) { 1789d13fa56SJunchao Zhang PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes)); 1799d13fa56SJunchao Zhang jac->spptr = pckok; 180f1be3500SJunchao Zhang } 181f1be3500SJunchao Zhang 1829d13fa56SJunchao Zhang // Extract diagonal blocks from the matrix and compute their inverse 1839d13fa56SJunchao Zhang const auto &bs = pckok->bs_dual.view_device(); 1849d13fa56SJunchao Zhang const auto &bs2 = pckok->bs2_dual.view_device(); 1859d13fa56SJunchao Zhang const auto &blkMap = pckok->blkMap_dual.view_device(); 18642ce410bSJunchao Zhang if (diagVPB) { // If caller provided a matrix made of the diagonal blocks, use it 187077598ddSJames Wright PetscCall(PetscObjectBaseTypeCompare((PetscObject)diagVPB, MATMPIAIJ, &ismpi)); 188077598ddSJames Wright A = ismpi ? static_cast<Mat_MPIAIJ *>(diagVPB->data)->A : diagVPB; 18942ce410bSJunchao Zhang } else { 190077598ddSJames Wright PetscCall(PetscObjectBaseTypeCompare((PetscObject)pc->pmat, MATMPIAIJ, &ismpi)); 191f4f49eeaSPierre Jolivet A = ismpi ? static_cast<Mat_MPIAIJ *>(pc->pmat->data)->A : pc->pmat; 19242ce410bSJunchao Zhang } 1935994c5adSJunchao Zhang PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(A, bs, bs2, blkMap, pckok->work, pckok->diag)); 19469eda9daSJed Brown pc->ops->apply = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>; 19569eda9daSJed Brown pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>; 196f1be3500SJunchao Zhang pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos; 1979d13fa56SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1989d13fa56SJunchao Zhang PetscCall(PetscLogGpuFlops(pckok->setupFlops)); 1993ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 200f1be3500SJunchao Zhang } 201