xref: /petsc/src/ksp/pc/impls/pbjacobi/kokkos/pbjacobi_kok.kokkos.cxx (revision 42ce410b93811b6a1b27796b46869d3e6c55267a)
112facf1bSJunchao Zhang #include <petscvec_kokkos.hpp>
2a74aa489SJunchao Zhang #include <petsc_kokkos.hpp>
312facf1bSJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
412facf1bSJunchao Zhang #include <petscdevice.h>
512facf1bSJunchao Zhang #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
612facf1bSJunchao Zhang 
712facf1bSJunchao Zhang struct PC_PBJacobi_Kokkos {
812facf1bSJunchao Zhang   PetscScalarKokkosDualView diag_dual;
912facf1bSJunchao Zhang 
1012facf1bSJunchao Zhang   PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h)
1112facf1bSJunchao Zhang   {
1212facf1bSJunchao Zhang     PetscScalarKokkosViewHost diag_h(diag_ptr_h, len);
13a74aa489SJunchao Zhang     auto                      diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h);
1412facf1bSJunchao Zhang     diag_dual                        = PetscScalarKokkosDualView(diag_d, diag_h);
1512facf1bSJunchao Zhang   }
1612facf1bSJunchao Zhang 
1712facf1bSJunchao Zhang   PetscErrorCode Update(const PetscScalar *diag_ptr_h)
1812facf1bSJunchao Zhang   {
19a74aa489SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
20a74aa489SJunchao Zhang 
2112facf1bSJunchao Zhang     PetscFunctionBegin;
2212facf1bSJunchao Zhang     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
2312facf1bSJunchao Zhang     PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */
24a74aa489SJunchao Zhang     PetscCallCXX(diag_dual.sync_device(exec));
253ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
2612facf1bSJunchao Zhang   }
2712facf1bSJunchao Zhang };
2812facf1bSJunchao Zhang 
2912facf1bSJunchao Zhang /* Make 'transpose' a template parameter instead of a function input parameter, so that
3012facf1bSJunchao Zhang  it will be a const in template instantiation and gets optimized out.
3112facf1bSJunchao Zhang */
3212facf1bSJunchao Zhang template <PetscBool transpose>
3312facf1bSJunchao Zhang static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y)
3412facf1bSJunchao Zhang {
3512facf1bSJunchao Zhang   PC_PBJacobi               *jac   = (PC_PBJacobi *)pc->data;
3612facf1bSJunchao Zhang   PC_PBJacobi_Kokkos        *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
3712facf1bSJunchao Zhang   ConstPetscScalarKokkosView xv;
3812facf1bSJunchao Zhang   PetscScalarKokkosView      yv;
3912facf1bSJunchao Zhang   PetscScalarKokkosView      Av = pckok->diag_dual.view_device();
4012facf1bSJunchao Zhang   const PetscInt             bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs;
4112facf1bSJunchao Zhang   const char                *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos";
4212facf1bSJunchao Zhang 
4312facf1bSJunchao Zhang   PetscFunctionBegin;
4412facf1bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
4512facf1bSJunchao Zhang   VecErrorIfNotKokkos(x);
4612facf1bSJunchao Zhang   VecErrorIfNotKokkos(y);
4712facf1bSJunchao Zhang   PetscCall(VecGetKokkosView(x, &xv));
4812facf1bSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(y, &yv));
4912facf1bSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
50a74aa489SJunchao Zhang     label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) {
5112facf1bSJunchao Zhang       const PetscScalar *Ap, *xp;
5212facf1bSJunchao Zhang       PetscScalar       *yp;
5312facf1bSJunchao Zhang       PetscInt           i, j, k;
5412facf1bSJunchao Zhang 
5512facf1bSJunchao Zhang       k  = row / bs;                                /* k-th block */
5612facf1bSJunchao Zhang       i  = row % bs;                                /* this thread deals with i-th row of the block */
5712facf1bSJunchao Zhang       Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */
5812facf1bSJunchao Zhang       xp = &xv(bs * k);
5912facf1bSJunchao Zhang       yp = &yv(bs * k);
6012facf1bSJunchao Zhang       /* multiply i-th row (column) with x */
6112facf1bSJunchao Zhang       yp[i] = 0.0;
6212facf1bSJunchao Zhang       for (j = 0; j < bs; j++) {
6312facf1bSJunchao Zhang         yp[i] += Ap[0] * xp[j];
6412facf1bSJunchao Zhang         Ap += (transpose ? 1 : bs); /* block is in column major order */
6512facf1bSJunchao Zhang       }
6612facf1bSJunchao Zhang     }));
6712facf1bSJunchao Zhang   PetscCall(VecRestoreKokkosView(x, &xv));
6812facf1bSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
6912facf1bSJunchao Zhang   PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */
7012facf1bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
7212facf1bSJunchao Zhang }
7312facf1bSJunchao Zhang 
7412facf1bSJunchao Zhang static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc)
7512facf1bSJunchao Zhang {
7612facf1bSJunchao Zhang   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
7712facf1bSJunchao Zhang 
7812facf1bSJunchao Zhang   PetscFunctionBegin;
7912facf1bSJunchao Zhang   PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr));
8012facf1bSJunchao Zhang   PetscCall(PCDestroy_PBJacobi(pc));
813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
8212facf1bSJunchao Zhang }
8312facf1bSJunchao Zhang 
84*42ce410bSJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB)
8512facf1bSJunchao Zhang {
8612facf1bSJunchao Zhang   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
8712facf1bSJunchao Zhang   PetscInt     len;
8812facf1bSJunchao Zhang 
8912facf1bSJunchao Zhang   PetscFunctionBegin;
90*42ce410bSJunchao Zhang   PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */
9112facf1bSJunchao Zhang   len = jac->bs * jac->bs * jac->mbs;
9212facf1bSJunchao Zhang   if (!jac->spptr) {
9312facf1bSJunchao Zhang     PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag)));
9412facf1bSJunchao Zhang   } else {
9512facf1bSJunchao Zhang     PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
9612facf1bSJunchao Zhang     PetscCall(pckok->Update(jac->diag));
9712facf1bSJunchao Zhang   }
9812facf1bSJunchao Zhang   PetscCall(PetscLogCpuToGpu(sizeof(PetscScalar) * len));
9912facf1bSJunchao Zhang 
10012facf1bSJunchao Zhang   pc->ops->apply          = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>;
10112facf1bSJunchao Zhang   pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>;
10212facf1bSJunchao Zhang   pc->ops->destroy        = PCDestroy_PBJacobi_Kokkos;
1033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
10412facf1bSJunchao Zhang }
105