xref: /petsc/src/ksp/pc/impls/pbjacobi/kokkos/pbjacobi_kok.kokkos.cxx (revision a24cdd0d2637ee09544835b4ece87d1f040d2dfd)
1 #include <petscvec_kokkos.hpp>
2 #include <petsc_kokkos.hpp>
3 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
4 #include <petscdevice.h>
5 #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
6 
7 struct PC_PBJacobi_Kokkos {
8   PetscScalarKokkosDualView diag_dual;
9 
10   PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h)
11   {
12     PetscScalarKokkosViewHost diag_h(diag_ptr_h, len);
13     auto                      diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h);
14     diag_dual                        = PetscScalarKokkosDualView(diag_d, diag_h);
15   }
16 
17   PetscErrorCode Update(const PetscScalar *diag_ptr_h)
18   {
19     auto &exec = PetscGetKokkosExecutionSpace();
20 
21     PetscFunctionBegin;
22     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
23     PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */
24     PetscCallCXX(diag_dual.sync_device(exec));
25     PetscFunctionReturn(PETSC_SUCCESS);
26   }
27 };
28 
29 /* Make 'transpose' a template parameter instead of a function input parameter, so that
30  it will be a const in template instantiation and gets optimized out.
31 */
32 template <PetscBool transpose>
33 static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y)
34 {
35   PC_PBJacobi               *jac   = (PC_PBJacobi *)pc->data;
36   PC_PBJacobi_Kokkos        *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
37   ConstPetscScalarKokkosView xv;
38   PetscScalarKokkosView      yv;
39   PetscScalarKokkosView      Av = pckok->diag_dual.view_device();
40   const PetscInt             bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs;
41   const char                *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos";
42 
43   PetscFunctionBegin;
44   PetscCall(PetscLogGpuTimeBegin());
45   VecErrorIfNotKokkos(x);
46   VecErrorIfNotKokkos(y);
47   PetscCall(VecGetKokkosView(x, &xv));
48   PetscCall(VecGetKokkosViewWrite(y, &yv));
49   PetscCallCXX(Kokkos::parallel_for(
50     label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) {
51       const PetscScalar *Ap, *xp;
52       PetscScalar       *yp;
53       PetscInt           i, j, k;
54 
55       k  = row / bs;                                /* k-th block */
56       i  = row % bs;                                /* this thread deals with i-th row of the block */
57       Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */
58       xp = &xv(bs * k);
59       yp = &yv(bs * k);
60       /* multiply i-th row (column) with x */
61       yp[i] = 0.0;
62       for (j = 0; j < bs; j++) {
63         yp[i] += Ap[0] * xp[j];
64         Ap += (transpose ? 1 : bs); /* block is in column major order */
65       }
66     }));
67   PetscCall(VecRestoreKokkosView(x, &xv));
68   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
69   PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */
70   PetscCall(PetscLogGpuTimeEnd());
71   PetscFunctionReturn(PETSC_SUCCESS);
72 }
73 
74 static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc)
75 {
76   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
77 
78   PetscFunctionBegin;
79   PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr));
80   PetscCall(PCDestroy_PBJacobi(pc));
81   PetscFunctionReturn(PETSC_SUCCESS);
82 }
83 
84 PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB)
85 {
86   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
87   PetscInt     len;
88 
89   PetscFunctionBegin;
90   PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */
91   len = jac->bs * jac->bs * jac->mbs;
92   if (!jac->spptr) {
93     PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag)));
94   } else {
95     PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
96     PetscCall(pckok->Update(jac->diag));
97   }
98   PetscCall(PetscLogCpuToGpu(sizeof(PetscScalar) * len));
99 
100   pc->ops->apply          = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>;
101   pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>;
102   pc->ops->destroy        = PCDestroy_PBJacobi_Kokkos;
103   PetscFunctionReturn(PETSC_SUCCESS);
104 }
105