1f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp> 2f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 3f1be3500SJunchao Zhang #include <petscdevice.h> 4f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h> 5f1be3500SJunchao Zhang 6f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */ 7f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos { 8f1be3500SJunchao Zhang /* Cache the old sizes to check if we need realloc */ 9f1be3500SJunchao Zhang PetscInt n; /* number of rows of the local matrix */ 10f1be3500SJunchao Zhang PetscInt nblocks; /* number of point blocks */ 11f1be3500SJunchao Zhang PetscInt nsize; /* sum of sizes of the point blocks */ 12f1be3500SJunchao Zhang 13f1be3500SJunchao Zhang /* Helper arrays that are pre-computed on host and then copied to device. 14f1be3500SJunchao Zhang bs: [nblocks+1], "csr" version of bsizes[] 15f1be3500SJunchao Zhang bs2: [nblocks+1], "csr" version of squares of bsizes[] 16f1be3500SJunchao Zhang matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block 17f1be3500SJunchao Zhang */ 18f1be3500SJunchao Zhang PetscIntKokkosDualView bs_dual, bs2_dual, matIdx_dual; 19f1be3500SJunchao Zhang PetscScalarKokkosDualView diag_dual; 20f1be3500SJunchao Zhang 219371c9d4SSatish Balay PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes, MatScalar *diag_ptr_h) : 22d71ae5a4SJacob Faibussowitsch n(n), nblocks(nblocks), nsize(nsize), bs_dual("bs_dual", nblocks + 1), bs2_dual("bs2_dual", nblocks + 1), matIdx_dual("matIdx_dual", n) 23d71ae5a4SJacob Faibussowitsch { 24f1be3500SJunchao Zhang PetscScalarKokkosViewHost diag_h(diag_ptr_h, nsize); 25f1be3500SJunchao Zhang 26f1be3500SJunchao Zhang auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(), diag_h); 27f1be3500SJunchao Zhang diag_dual = PetscScalarKokkosDualView(diag_d, diag_h); 28f1be3500SJunchao Zhang PetscCallVoid(UpdateOffsetsOnDevice(bsizes, diag_ptr_h)); 29f1be3500SJunchao Zhang } 30f1be3500SJunchao Zhang 31d71ae5a4SJacob Faibussowitsch PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes, MatScalar *diag_ptr_h) 32d71ae5a4SJacob Faibussowitsch { 33f1be3500SJunchao Zhang PetscFunctionBegin; 34f1be3500SJunchao Zhang PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call"); 35f1be3500SJunchao Zhang PetscCall(ComputeOffsetsOnHost(bsizes)); 36f1be3500SJunchao Zhang 37f1be3500SJunchao Zhang PetscCallCXX(bs_dual.modify_host()); 38f1be3500SJunchao Zhang PetscCallCXX(bs2_dual.modify_host()); 39f1be3500SJunchao Zhang PetscCallCXX(matIdx_dual.modify_host()); 40f1be3500SJunchao Zhang PetscCallCXX(diag_dual.modify_host()); 41f1be3500SJunchao Zhang 42f1be3500SJunchao Zhang PetscCallCXX(bs_dual.sync_device()); 43f1be3500SJunchao Zhang PetscCallCXX(bs2_dual.sync_device()); 44f1be3500SJunchao Zhang PetscCallCXX(matIdx_dual.sync_device()); 45f1be3500SJunchao Zhang PetscCallCXX(diag_dual.sync_device()); 469a56b474SJunchao Zhang PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * nblocks + 2 + n) + sizeof(MatScalar) * nsize)); 47*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 48f1be3500SJunchao Zhang } 49f1be3500SJunchao Zhang 50f1be3500SJunchao Zhang private: 51d71ae5a4SJacob Faibussowitsch PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes) 52d71ae5a4SJacob Faibussowitsch { 53f1be3500SJunchao Zhang PetscInt *bs_h = bs_dual.view_host().data(); 54f1be3500SJunchao Zhang PetscInt *bs2_h = bs2_dual.view_host().data(); 55f1be3500SJunchao Zhang PetscInt *matIdx_h = matIdx_dual.view_host().data(); 56f1be3500SJunchao Zhang 57f1be3500SJunchao Zhang PetscFunctionBegin; 58f1be3500SJunchao Zhang bs_h[0] = bs2_h[0] = 0; 59f1be3500SJunchao Zhang for (PetscInt i = 0; i < nblocks; i++) { 60f1be3500SJunchao Zhang bs_h[i + 1] = bs_h[i] + bsizes[i]; 61f1be3500SJunchao Zhang bs2_h[i + 1] = bs2_h[i] + bsizes[i] * bsizes[i]; 62f1be3500SJunchao Zhang for (PetscInt j = 0; j < bsizes[i]; j++) matIdx_h[bs_h[i] + j] = i; 63f1be3500SJunchao Zhang } 64*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 65f1be3500SJunchao Zhang } 66f1be3500SJunchao Zhang }; 67f1be3500SJunchao Zhang 6869eda9daSJed Brown template <PetscBool transpose> 6969eda9daSJed Brown static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y) 70d71ae5a4SJacob Faibussowitsch { 71f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 72f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 73f1be3500SJunchao Zhang ConstPetscScalarKokkosView xv; 74f1be3500SJunchao Zhang PetscScalarKokkosView yv; 75f1be3500SJunchao Zhang PetscScalarKokkosView diag = pckok->diag_dual.view_device(); 76f1be3500SJunchao Zhang PetscIntKokkosView bs = pckok->bs_dual.view_device(); 77f1be3500SJunchao Zhang PetscIntKokkosView bs2 = pckok->bs2_dual.view_device(); 78f1be3500SJunchao Zhang PetscIntKokkosView matIdx = pckok->matIdx_dual.view_device(); 7969eda9daSJed Brown const char *label = transpose ? "PCApplyTranspose_VPBJacobi_Kokkos" : "PCApply_VPBJacobi_Kokkos"; 80f1be3500SJunchao Zhang 81f1be3500SJunchao Zhang PetscFunctionBegin; 829a56b474SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 83f1be3500SJunchao Zhang VecErrorIfNotKokkos(x); 84f1be3500SJunchao Zhang VecErrorIfNotKokkos(y); 85f1be3500SJunchao Zhang PetscCall(VecGetKokkosView(x, &xv)); 86f1be3500SJunchao Zhang PetscCall(VecGetKokkosViewWrite(y, &yv)); 879371c9d4SSatish Balay PetscCallCXX(Kokkos::parallel_for( 8869eda9daSJed Brown label, pckok->n, KOKKOS_LAMBDA(PetscInt row) { 89f1be3500SJunchao Zhang const PetscScalar *Ap, *xp; 90f1be3500SJunchao Zhang PetscScalar *yp; 91f1be3500SJunchao Zhang PetscInt i, j, k, m; 92f1be3500SJunchao Zhang 93f1be3500SJunchao Zhang k = matIdx(row); /* k-th block/matrix */ 94f1be3500SJunchao Zhang m = bs(k + 1) - bs(k); /* block size of the k-th block */ 95f1be3500SJunchao Zhang i = row - bs(k); /* i-th row of the block */ 9669eda9daSJed Brown Ap = &diag(bs2(k) + i * (transpose ? m : 1)); /* Ap points to the first entry of i-th row/column */ 97f1be3500SJunchao Zhang xp = &xv(bs(k)); 98f1be3500SJunchao Zhang yp = &yv(bs(k)); 99f1be3500SJunchao Zhang 100f1be3500SJunchao Zhang yp[i] = 0.0; 1019371c9d4SSatish Balay for (j = 0; j < m; j++) { 1029371c9d4SSatish Balay yp[i] += Ap[0] * xp[j]; 10369eda9daSJed Brown Ap += transpose ? 1 : m; 1049371c9d4SSatish Balay } 105f1be3500SJunchao Zhang })); 106f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosView(x, &xv)); 107f1be3500SJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 1089a56b474SJunchao Zhang PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */ 1099a56b474SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 110*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 111f1be3500SJunchao Zhang } 112f1be3500SJunchao Zhang 113d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc) 114d71ae5a4SJacob Faibussowitsch { 115f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 116f1be3500SJunchao Zhang 117f1be3500SJunchao Zhang PetscFunctionBegin; 118f1be3500SJunchao Zhang PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr)); 119f1be3500SJunchao Zhang PetscCall(PCDestroy_VPBJacobi(pc)); 120*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 121f1be3500SJunchao Zhang } 122f1be3500SJunchao Zhang 123d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc) 124d71ae5a4SJacob Faibussowitsch { 125f1be3500SJunchao Zhang PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data; 126f1be3500SJunchao Zhang PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr); 127f1be3500SJunchao Zhang PetscInt i, n, nblocks, nsize = 0; 128f1be3500SJunchao Zhang const PetscInt *bsizes; 129f1be3500SJunchao Zhang 130f1be3500SJunchao Zhang PetscFunctionBegin; 131f1be3500SJunchao Zhang PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */ 132f1be3500SJunchao Zhang PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes)); 133f1be3500SJunchao Zhang for (i = 0; i < nblocks; i++) nsize += bsizes[i] * bsizes[i]; 134f1be3500SJunchao Zhang PetscCall(MatGetLocalSize(pc->pmat, &n, NULL)); 135f1be3500SJunchao Zhang 136f1be3500SJunchao Zhang /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */ 137f1be3500SJunchao Zhang if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) { 138f1be3500SJunchao Zhang PetscCallCXX(delete pckok); 139f1be3500SJunchao Zhang pckok = nullptr; 140f1be3500SJunchao Zhang } 141f1be3500SJunchao Zhang 14235cb6cd3SPierre Jolivet if (!pckok) { /* allocate the struct along with the helper arrays from the scratch */ 143f1be3500SJunchao Zhang PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n, nblocks, nsize, bsizes, jac->diag)); 144f1be3500SJunchao Zhang } else { /* update the value only */ 145f1be3500SJunchao Zhang PetscCall(pckok->UpdateOffsetsOnDevice(bsizes, jac->diag)); 146f1be3500SJunchao Zhang } 147f1be3500SJunchao Zhang 14869eda9daSJed Brown pc->ops->apply = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>; 14969eda9daSJed Brown pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>; 150f1be3500SJunchao Zhang pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos; 151*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 152f1be3500SJunchao Zhang } 153