112facf1bSJunchao Zhang #include <petscvec_kokkos.hpp> 2a74aa489SJunchao Zhang #include <petsc_kokkos.hpp> 312facf1bSJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp> 412facf1bSJunchao Zhang #include <petscdevice.h> 512facf1bSJunchao Zhang #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h> 612facf1bSJunchao Zhang 712facf1bSJunchao Zhang struct PC_PBJacobi_Kokkos { 812facf1bSJunchao Zhang PetscScalarKokkosDualView diag_dual; 912facf1bSJunchao Zhang 1012facf1bSJunchao Zhang PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h) 1112facf1bSJunchao Zhang { 1212facf1bSJunchao Zhang PetscScalarKokkosViewHost diag_h(diag_ptr_h, len); 13a74aa489SJunchao Zhang auto diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h); 1412facf1bSJunchao Zhang diag_dual = PetscScalarKokkosDualView(diag_d, diag_h); 1512facf1bSJunchao Zhang } 1612facf1bSJunchao Zhang 1712facf1bSJunchao Zhang PetscErrorCode Update(const PetscScalar *diag_ptr_h) 1812facf1bSJunchao Zhang { 1912facf1bSJunchao Zhang PetscFunctionBegin; 2012facf1bSJunchao Zhang PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call"); 2112facf1bSJunchao Zhang PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */ 22*f3d3cd90SZach Atkins PetscCall(KokkosDualViewSyncDevice(diag_dual, PetscGetKokkosExecutionSpace())); 233ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 2412facf1bSJunchao Zhang } 2512facf1bSJunchao Zhang }; 2612facf1bSJunchao Zhang 2712facf1bSJunchao Zhang /* Make 'transpose' a template parameter instead of a function input parameter, so that 2812facf1bSJunchao Zhang it will be a const in template instantiation and gets optimized out. 2912facf1bSJunchao Zhang */ 3012facf1bSJunchao Zhang template <PetscBool transpose> 3112facf1bSJunchao Zhang static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y) 3212facf1bSJunchao Zhang { 3312facf1bSJunchao Zhang PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 3412facf1bSJunchao Zhang PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr); 3512facf1bSJunchao Zhang ConstPetscScalarKokkosView xv; 3612facf1bSJunchao Zhang PetscScalarKokkosView yv; 3712facf1bSJunchao Zhang PetscScalarKokkosView Av = pckok->diag_dual.view_device(); 3812facf1bSJunchao Zhang const PetscInt bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs; 3912facf1bSJunchao Zhang const char *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos"; 4012facf1bSJunchao Zhang 4112facf1bSJunchao Zhang PetscFunctionBegin; 4212facf1bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 4312facf1bSJunchao Zhang VecErrorIfNotKokkos(x); 4412facf1bSJunchao Zhang VecErrorIfNotKokkos(y); 4512facf1bSJunchao Zhang PetscCall(VecGetKokkosView(x, &xv)); 4612facf1bSJunchao Zhang PetscCall(VecGetKokkosViewWrite(y, &yv)); 4712facf1bSJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 48a74aa489SJunchao Zhang label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) { 4912facf1bSJunchao Zhang const PetscScalar *Ap, *xp; 5012facf1bSJunchao Zhang PetscScalar *yp; 5112facf1bSJunchao Zhang PetscInt i, j, k; 5212facf1bSJunchao Zhang 5312facf1bSJunchao Zhang k = row / bs; /* k-th block */ 5412facf1bSJunchao Zhang i = row % bs; /* this thread deals with i-th row of the block */ 5512facf1bSJunchao Zhang Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */ 5612facf1bSJunchao Zhang xp = &xv(bs * k); 5712facf1bSJunchao Zhang yp = &yv(bs * k); 5812facf1bSJunchao Zhang /* multiply i-th row (column) with x */ 5912facf1bSJunchao Zhang yp[i] = 0.0; 6012facf1bSJunchao Zhang for (j = 0; j < bs; j++) { 6112facf1bSJunchao Zhang yp[i] += Ap[0] * xp[j]; 6212facf1bSJunchao Zhang Ap += (transpose ? 1 : bs); /* block is in column major order */ 6312facf1bSJunchao Zhang } 6412facf1bSJunchao Zhang })); 6512facf1bSJunchao Zhang PetscCall(VecRestoreKokkosView(x, &xv)); 6612facf1bSJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(y, &yv)); 6712facf1bSJunchao Zhang PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */ 6812facf1bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 693ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 7012facf1bSJunchao Zhang } 7112facf1bSJunchao Zhang 7212facf1bSJunchao Zhang static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc) 7312facf1bSJunchao Zhang { 7412facf1bSJunchao Zhang PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 7512facf1bSJunchao Zhang 7612facf1bSJunchao Zhang PetscFunctionBegin; 7712facf1bSJunchao Zhang PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr)); 7812facf1bSJunchao Zhang PetscCall(PCDestroy_PBJacobi(pc)); 793ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 8012facf1bSJunchao Zhang } 8112facf1bSJunchao Zhang 8242ce410bSJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB) 8312facf1bSJunchao Zhang { 8412facf1bSJunchao Zhang PC_PBJacobi *jac = (PC_PBJacobi *)pc->data; 8512facf1bSJunchao Zhang PetscInt len; 8612facf1bSJunchao Zhang 8712facf1bSJunchao Zhang PetscFunctionBegin; 8842ce410bSJunchao Zhang PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */ 8912facf1bSJunchao Zhang len = jac->bs * jac->bs * jac->mbs; 9012facf1bSJunchao Zhang if (!jac->spptr) { 9112facf1bSJunchao Zhang PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag))); 9212facf1bSJunchao Zhang } else { 9312facf1bSJunchao Zhang PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr); 9412facf1bSJunchao Zhang PetscCall(pckok->Update(jac->diag)); 9512facf1bSJunchao Zhang } 9612facf1bSJunchao Zhang 9712facf1bSJunchao Zhang pc->ops->apply = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>; 9812facf1bSJunchao Zhang pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>; 9912facf1bSJunchao Zhang pc->ops->destroy = PCDestroy_PBJacobi_Kokkos; 1003ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 10112facf1bSJunchao Zhang } 102