1*12facf1bSJunchao Zhang #include <petscvec_kokkos.hpp> 2*12facf1bSJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 3*12facf1bSJunchao Zhang #include <petscdevice.h> 4*12facf1bSJunchao Zhang #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h> 5*12facf1bSJunchao Zhang 6*12facf1bSJunchao Zhang struct PC_PBJacobi_Kokkos { 7*12facf1bSJunchao Zhang PetscScalarKokkosDualView diag_dual; 8*12facf1bSJunchao Zhang 9*12facf1bSJunchao Zhang PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h) 10*12facf1bSJunchao Zhang { 11*12facf1bSJunchao Zhang PetscScalarKokkosViewHost diag_h(diag_ptr_h, len); 12*12facf1bSJunchao Zhang auto diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h); 13*12facf1bSJunchao Zhang diag_dual = PetscScalarKokkosDualView(diag_d, diag_h); 14*12facf1bSJunchao Zhang } 15*12facf1bSJunchao Zhang 16*12facf1bSJunchao Zhang PetscErrorCode Update(const PetscScalar *diag_ptr_h) 17*12facf1bSJunchao Zhang { 18*12facf1bSJunchao Zhang PetscFunctionBegin; 19*12facf1bSJunchao Zhang PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call"); 20*12facf1bSJunchao Zhang PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */ 21*12facf1bSJunchao Zhang PetscCallCXX(diag_dual.sync_device()); 22*12facf1bSJunchao Zhang PetscFunctionReturn(0); 23*12facf1bSJunchao Zhang } 24*12facf1bSJunchao Zhang }; 25*12facf1bSJunchao Zhang 26*12facf1bSJunchao Zhang /* Make 'transpose' a template parameter instead of a function input parameter, so that 27*12facf1bSJunchao Zhang it will be a const in template instantiation and gets optimized out. 28*12facf1bSJunchao Zhang */ 29*12facf1bSJunchao Zhang template <PetscBool transpose> 30*12facf1bSJunchao Zhang static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y) 31*12facf1bSJunchao Zhang { 32*12facf1bSJunchao Zhang PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 33*12facf1bSJunchao Zhang PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr); 34*12facf1bSJunchao Zhang ConstPetscScalarKokkosView xv; 35*12facf1bSJunchao Zhang PetscScalarKokkosView yv; 36*12facf1bSJunchao Zhang PetscScalarKokkosView Av = pckok->diag_dual.view_device(); 37*12facf1bSJunchao Zhang const PetscInt bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs; 38*12facf1bSJunchao Zhang const char *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos"; 39*12facf1bSJunchao Zhang 40*12facf1bSJunchao Zhang PetscFunctionBegin; 41*12facf1bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 42*12facf1bSJunchao Zhang VecErrorIfNotKokkos(x); 43*12facf1bSJunchao Zhang VecErrorIfNotKokkos(y); 44*12facf1bSJunchao Zhang PetscCall(VecGetKokkosView(x, &xv)); 45*12facf1bSJunchao Zhang PetscCall(VecGetKokkosViewWrite(y, &yv)); 46*12facf1bSJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 47*12facf1bSJunchao Zhang label, bs * mbs, KOKKOS_LAMBDA(PetscInt row) { 48*12facf1bSJunchao Zhang const PetscScalar *Ap, *xp; 49*12facf1bSJunchao Zhang PetscScalar *yp; 50*12facf1bSJunchao Zhang PetscInt i, j, k; 51*12facf1bSJunchao Zhang 52*12facf1bSJunchao Zhang k = row / bs; /* k-th block */ 53*12facf1bSJunchao Zhang i = row % bs; /* this thread deals with i-th row of the block */ 54*12facf1bSJunchao Zhang Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */ 55*12facf1bSJunchao Zhang xp = &xv(bs * k); 56*12facf1bSJunchao Zhang yp = &yv(bs * k); 57*12facf1bSJunchao Zhang /* multiply i-th row (column) with x */ 58*12facf1bSJunchao Zhang yp[i] = 0.0; 59*12facf1bSJunchao Zhang for (j = 0; j < bs; j++) { 60*12facf1bSJunchao Zhang yp[i] += Ap[0] * xp[j]; 61*12facf1bSJunchao Zhang Ap += (transpose ? 1 : bs); /* block is in column major order */ 62*12facf1bSJunchao Zhang } 63*12facf1bSJunchao Zhang })); 64*12facf1bSJunchao Zhang PetscCall(VecRestoreKokkosView(x, &xv)); 65*12facf1bSJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 66*12facf1bSJunchao Zhang PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */ 67*12facf1bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 68*12facf1bSJunchao Zhang PetscFunctionReturn(0); 69*12facf1bSJunchao Zhang } 70*12facf1bSJunchao Zhang 71*12facf1bSJunchao Zhang static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc) 72*12facf1bSJunchao Zhang { 73*12facf1bSJunchao Zhang PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 74*12facf1bSJunchao Zhang 75*12facf1bSJunchao Zhang PetscFunctionBegin; 76*12facf1bSJunchao Zhang PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr)); 77*12facf1bSJunchao Zhang PetscCall(PCDestroy_PBJacobi(pc)); 78*12facf1bSJunchao Zhang PetscFunctionReturn(0); 79*12facf1bSJunchao Zhang } 80*12facf1bSJunchao Zhang 81*12facf1bSJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc) 82*12facf1bSJunchao Zhang { 83*12facf1bSJunchao Zhang PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 84*12facf1bSJunchao Zhang PetscInt len; 85*12facf1bSJunchao Zhang 86*12facf1bSJunchao Zhang PetscFunctionBegin; 87*12facf1bSJunchao Zhang PetscCall(PCSetUp_PBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */ 88*12facf1bSJunchao Zhang len = jac->bs * jac->bs * jac->mbs; 89*12facf1bSJunchao Zhang if (!jac->spptr) { 90*12facf1bSJunchao Zhang PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag))); 91*12facf1bSJunchao Zhang } else { 92*12facf1bSJunchao Zhang PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr); 93*12facf1bSJunchao Zhang PetscCall(pckok->Update(jac->diag)); 94*12facf1bSJunchao Zhang } 95*12facf1bSJunchao Zhang PetscCall(PetscLogCpuToGpu(sizeof(PetscScalar) * len)); 96*12facf1bSJunchao Zhang 97*12facf1bSJunchao Zhang pc->ops->apply = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>; 98*12facf1bSJunchao Zhang pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>; 99*12facf1bSJunchao Zhang pc->ops->destroy = PCDestroy_PBJacobi_Kokkos; 100*12facf1bSJunchao Zhang PetscFunctionReturn(0); 101*12facf1bSJunchao Zhang } 102