1 #include <petscvec_kokkos.hpp> 2 #include <petsc_kokkos.hpp> 3 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 4 #include <petscdevice.h> 5 #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h> 6 7 struct PC_PBJacobi_Kokkos { 8 PetscScalarKokkosDualView diag_dual; 9 10 PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h) 11 { 12 PetscScalarKokkosViewHost diag_h(diag_ptr_h, len); 13 auto diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h); 14 diag_dual = PetscScalarKokkosDualView(diag_d, diag_h); 15 } 16 17 PetscErrorCode Update(const PetscScalar *diag_ptr_h) 18 { 19 auto &exec = PetscGetKokkosExecutionSpace(); 20 21 PetscFunctionBegin; 22 PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call"); 23 PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */ 24 PetscCallCXX(diag_dual.sync_device(exec)); 25 PetscFunctionReturn(PETSC_SUCCESS); 26 } 27 }; 28 29 /* Make 'transpose' a template parameter instead of a function input parameter, so that 30 it will be a const in template instantiation and gets optimized out. 31 */ 32 template <PetscBool transpose> 33 static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y) 34 { 35 PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 36 PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr); 37 ConstPetscScalarKokkosView xv; 38 PetscScalarKokkosView yv; 39 PetscScalarKokkosView Av = pckok->diag_dual.view_device(); 40 const PetscInt bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs; 41 const char *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos"; 42 43 PetscFunctionBegin; 44 PetscCall(PetscLogGpuTimeBegin()); 45 VecErrorIfNotKokkos(x); 46 VecErrorIfNotKokkos(y); 47 PetscCall(VecGetKokkosView(x, &xv)); 48 PetscCall(VecGetKokkosViewWrite(y, &yv)); 49 PetscCallCXX(Kokkos::parallel_for( 50 label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) { 51 const PetscScalar *Ap, *xp; 52 PetscScalar *yp; 53 PetscInt i, j, k; 54 55 k = row / bs; /* k-th block */ 56 i = row % bs; /* this thread deals with i-th row of the block */ 57 Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */ 58 xp = &xv(bs * k); 59 yp = &yv(bs * k); 60 /* multiply i-th row (column) with x */ 61 yp[i] = 0.0; 62 for (j = 0; j < bs; j++) { 63 yp[i] += Ap[0] * xp[j]; 64 Ap += (transpose ? 1 : bs); /* block is in column major order */ 65 } 66 })); 67 PetscCall(VecRestoreKokkosView(x, &xv)); 68 PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 69 PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */ 70 PetscCall(PetscLogGpuTimeEnd()); 71 PetscFunctionReturn(PETSC_SUCCESS); 72 } 73 74 static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc) 75 { 76 PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 77 78 PetscFunctionBegin; 79 PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr)); 80 PetscCall(PCDestroy_PBJacobi(pc)); 81 PetscFunctionReturn(PETSC_SUCCESS); 82 } 83 84 PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB) 85 { 86 PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 87 PetscInt len; 88 89 PetscFunctionBegin; 90 PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */ 91 len = jac->bs * jac->bs * jac->mbs; 92 if (!jac->spptr) { 93 PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag))); 94 } else { 95 PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr); 96 PetscCall(pckok->Update(jac->diag)); 97 } 98 PetscCall(PetscLogCpuToGpu(sizeof(PetscScalar) * len)); 99 100 pc->ops->apply = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>; 101 pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>; 102 pc->ops->destroy = PCDestroy_PBJacobi_Kokkos; 103 PetscFunctionReturn(PETSC_SUCCESS); 104 } 105