1f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp> 2f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 3f1be3500SJunchao Zhang #include <petscdevice.h> 4f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h> 5*9d13fa56SJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos 6f1be3500SJunchao Zhang 7f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */ 8f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos { 9f1be3500SJunchao Zhang /* Cache the old sizes to check if we need realloc */ 10f1be3500SJunchao Zhang PetscInt n; /* number of rows of the local matrix */ 11f1be3500SJunchao Zhang PetscInt nblocks; /* number of point blocks */ 12*9d13fa56SJunchao Zhang PetscInt nsize; /* sum of sizes (elements) of the point blocks */ 13f1be3500SJunchao Zhang 14f1be3500SJunchao Zhang /* Helper arrays that are pre-computed on host and then copied to device. 15f1be3500SJunchao Zhang bs: [nblocks+1], "csr" version of bsizes[] 16f1be3500SJunchao Zhang bs2: [nblocks+1], "csr" version of squares of bsizes[] 17*9d13fa56SJunchao Zhang blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block 18f1be3500SJunchao Zhang */ 19*9d13fa56SJunchao Zhang PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual; 20*9d13fa56SJunchao Zhang PetscScalarKokkosView diag; // buffer to store diagonal blocks 21*9d13fa56SJunchao Zhang PetscScalarKokkosView work; // work buffer, with the same size as diag[] 22*9d13fa56SJunchao Zhang PetscLogDouble setupFlops; 23f1be3500SJunchao Zhang 24*9d13fa56SJunchao Zhang // clang-format off 25*9d13fa56SJunchao Zhang // n: size of the matrix 26*9d13fa56SJunchao Zhang // nblocks: number of blocks 27*9d13fa56SJunchao Zhang // nsize: sum bsizes[i]^2 for i=0..nblocks 28*9d13fa56SJunchao Zhang // bsizes[nblocks]: sizes of blocks 29*9d13fa56SJunchao Zhang PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) : 30*9d13fa56SJunchao Zhang n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1), 31*9d13fa56SJunchao Zhang bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n), 32*9d13fa56SJunchao Zhang diag(NoInit("diag"), nsize), work(NoInit("work"), nsize) 33d71ae5a4SJacob Faibussowitsch { 34*9d13fa56SJunchao Zhang PetscCallVoid(BuildHelperArrays(bsizes)); 35f1be3500SJunchao Zhang } 36*9d13fa56SJunchao Zhang // clang-format on 37f1be3500SJunchao Zhang 38f1be3500SJunchao Zhang private: 39*9d13fa56SJunchao Zhang PetscErrorCode BuildHelperArrays(const PetscInt *bsizes) 40d71ae5a4SJacob Faibussowitsch { 41f1be3500SJunchao Zhang PetscInt *bs_h = bs_dual.view_host().data(); 42f1be3500SJunchao Zhang PetscInt *bs2_h = bs2_dual.view_host().data(); 43*9d13fa56SJunchao Zhang PetscInt *blkMap_h = blkMap_dual.view_host().data(); 44f1be3500SJunchao Zhang 45f1be3500SJunchao Zhang PetscFunctionBegin; 46*9d13fa56SJunchao Zhang setupFlops = 0.0; 47f1be3500SJunchao Zhang bs_h[0] = bs2_h[0] = 0; 48f1be3500SJunchao Zhang for (PetscInt i = 0; i < nblocks; i++) { 49*9d13fa56SJunchao Zhang PetscInt m = bsizes[i]; 50*9d13fa56SJunchao Zhang bs_h[i + 1] = bs_h[i] + m; 51*9d13fa56SJunchao Zhang bs2_h[i + 1] = bs2_h[i] + m * m; 52*9d13fa56SJunchao Zhang for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i; 53*9d13fa56SJunchao Zhang // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse 54*9d13fa56SJunchao Zhang setupFlops += 8.0 * m * m * m / 3; 55f1be3500SJunchao Zhang } 56*9d13fa56SJunchao Zhang 57*9d13fa56SJunchao Zhang PetscCallCXX(bs_dual.modify_host()); 58*9d13fa56SJunchao Zhang PetscCallCXX(bs2_dual.modify_host()); 59*9d13fa56SJunchao Zhang PetscCallCXX(blkMap_dual.modify_host()); 60*9d13fa56SJunchao Zhang PetscCallCXX(bs_dual.sync_device()); 61*9d13fa56SJunchao Zhang PetscCallCXX(bs2_dual.sync_device()); 62*9d13fa56SJunchao Zhang PetscCallCXX(blkMap_dual.sync_device()); 63*9d13fa56SJunchao Zhang PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * (nblocks + 1) + n))); 643ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 65f1be3500SJunchao Zhang } 66f1be3500SJunchao Zhang }; 67f1be3500SJunchao Zhang 6869eda9daSJed Brown template <PetscBool transpose> 6969eda9daSJed Brown static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y) 70d71ae5a4SJacob Faibussowitsch { 71f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 72f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 73f1be3500SJunchao Zhang ConstPetscScalarKokkosView xv; 74f1be3500SJunchao Zhang PetscScalarKokkosView yv; 75*9d13fa56SJunchao Zhang PetscScalarKokkosView diag = pckok->diag; 76f1be3500SJunchao Zhang PetscIntKokkosView bs = pckok->bs_dual.view_device(); 77f1be3500SJunchao Zhang PetscIntKokkosView bs2 = pckok->bs2_dual.view_device(); 78*9d13fa56SJunchao Zhang PetscIntKokkosView blkMap = pckok->blkMap_dual.view_device(); 7969eda9daSJed Brown const char *label = transpose ? "PCApplyTranspose_VPBJacobi_Kokkos" : "PCApply_VPBJacobi_Kokkos"; 80f1be3500SJunchao Zhang 81f1be3500SJunchao Zhang PetscFunctionBegin; 829a56b474SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 83f1be3500SJunchao Zhang VecErrorIfNotKokkos(x); 84f1be3500SJunchao Zhang VecErrorIfNotKokkos(y); 85f1be3500SJunchao Zhang PetscCall(VecGetKokkosView(x, &xv)); 86f1be3500SJunchao Zhang PetscCall(VecGetKokkosViewWrite(y, &yv)); 879371c9d4SSatish Balay PetscCallCXX(Kokkos::parallel_for( 8869eda9daSJed Brown label, pckok->n, KOKKOS_LAMBDA(PetscInt row) { 89f1be3500SJunchao Zhang const PetscScalar *Ap, *xp; 90f1be3500SJunchao Zhang PetscScalar *yp; 91f1be3500SJunchao Zhang PetscInt i, j, k, m; 92f1be3500SJunchao Zhang 93*9d13fa56SJunchao Zhang k = blkMap(row); /* k-th block/matrix */ 94f1be3500SJunchao Zhang m = bs(k + 1) - bs(k); /* block size of the k-th block */ 95f1be3500SJunchao Zhang i = row - bs(k); /* i-th row of the block */ 9669eda9daSJed Brown Ap = &diag(bs2(k) + i * (transpose ? m : 1)); /* Ap points to the first entry of i-th row/column */ 97f1be3500SJunchao Zhang xp = &xv(bs(k)); 98f1be3500SJunchao Zhang yp = &yv(bs(k)); 99f1be3500SJunchao Zhang 100f1be3500SJunchao Zhang yp[i] = 0.0; 1019371c9d4SSatish Balay for (j = 0; j < m; j++) { 1029371c9d4SSatish Balay yp[i] += Ap[0] * xp[j]; 10369eda9daSJed Brown Ap += transpose ? 1 : m; 1049371c9d4SSatish Balay } 105f1be3500SJunchao Zhang })); 106f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosView(x, &xv)); 107f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 1089a56b474SJunchao Zhang PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */ 1099a56b474SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1103ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 111f1be3500SJunchao Zhang } 112f1be3500SJunchao Zhang 113d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc) 114d71ae5a4SJacob Faibussowitsch { 115f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 116f1be3500SJunchao Zhang 117f1be3500SJunchao Zhang PetscFunctionBegin; 118f1be3500SJunchao Zhang PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr)); 119f1be3500SJunchao Zhang PetscCall(PCDestroy_VPBJacobi(pc)); 1203ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 121f1be3500SJunchao Zhang } 122f1be3500SJunchao Zhang 123d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc) 124d71ae5a4SJacob Faibussowitsch { 125f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 126f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 127*9d13fa56SJunchao Zhang PetscInt i, nlocal, nblocks, nsize = 0; 128f1be3500SJunchao Zhang const PetscInt *bsizes; 129f1be3500SJunchao Zhang 130f1be3500SJunchao Zhang PetscFunctionBegin; 131f1be3500SJunchao Zhang PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes)); 132*9d13fa56SJunchao Zhang PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL)); 133*9d13fa56SJunchao Zhang PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI"); 134f1be3500SJunchao Zhang 135*9d13fa56SJunchao Zhang if (!jac->diag) { 136*9d13fa56SJunchao Zhang PetscInt max_bs = -1, min_bs = PETSC_MAX_INT; 137*9d13fa56SJunchao Zhang for (i = 0; i < nblocks; i++) { 138*9d13fa56SJunchao Zhang min_bs = PetscMin(min_bs, bsizes[i]); 139*9d13fa56SJunchao Zhang max_bs = PetscMax(max_bs, bsizes[i]); 140*9d13fa56SJunchao Zhang nsize += bsizes[i] * bsizes[i]; 141*9d13fa56SJunchao Zhang } 142*9d13fa56SJunchao Zhang jac->nblocks = nblocks; 143*9d13fa56SJunchao Zhang jac->min_bs = min_bs; 144*9d13fa56SJunchao Zhang jac->max_bs = max_bs; 145*9d13fa56SJunchao Zhang } 146*9d13fa56SJunchao Zhang 147*9d13fa56SJunchao Zhang // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway 148*9d13fa56SJunchao Zhang if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) { 149f1be3500SJunchao Zhang PetscCallCXX(delete pckok); 150f1be3500SJunchao Zhang pckok = nullptr; 151f1be3500SJunchao Zhang } 152f1be3500SJunchao Zhang 153*9d13fa56SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 154*9d13fa56SJunchao Zhang if (!pckok) { 155*9d13fa56SJunchao Zhang PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes)); 156*9d13fa56SJunchao Zhang jac->spptr = pckok; 157f1be3500SJunchao Zhang } 158f1be3500SJunchao Zhang 159*9d13fa56SJunchao Zhang // Extract diagonal blocks from the matrix and compute their inverse 160*9d13fa56SJunchao Zhang const auto &bs = pckok->bs_dual.view_device(); 161*9d13fa56SJunchao Zhang const auto &bs2 = pckok->bs2_dual.view_device(); 162*9d13fa56SJunchao Zhang const auto &blkMap = pckok->blkMap_dual.view_device(); 163*9d13fa56SJunchao Zhang PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(pc->pmat, bs, bs2, blkMap, pckok->work, pckok->diag)); 16469eda9daSJed Brown pc->ops->apply = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>; 16569eda9daSJed Brown pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>; 166f1be3500SJunchao Zhang pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos; 167*9d13fa56SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 168*9d13fa56SJunchao Zhang PetscCall(PetscLogGpuFlops(pckok->setupFlops)); 1693ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 170f1be3500SJunchao Zhang } 171