xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision 934c28ddc29f2ef830f40fcfadab042dd386ea01)
1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp>
2f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp>
3f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
4f1be3500SJunchao Zhang #include <petscdevice.h>
5f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
69d13fa56SJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos
75994c5adSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>          // for Mat_MPIAIJ
86dc07963SJunchao Zhang #include <KokkosBlas2_gemv.hpp>
9f1be3500SJunchao Zhang 
10f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
11f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos {
12f1be3500SJunchao Zhang   /* Cache the old sizes to check if we need realloc */
13f1be3500SJunchao Zhang   PetscInt n;       /* number of rows of the local matrix */
14f1be3500SJunchao Zhang   PetscInt nblocks; /* number of point blocks */
159d13fa56SJunchao Zhang   PetscInt nsize;   /* sum of sizes (elements) of the point blocks */
16f1be3500SJunchao Zhang 
17f1be3500SJunchao Zhang   /* Helper arrays that are pre-computed on host and then copied to device.
18f1be3500SJunchao Zhang     bs:     [nblocks+1], "csr" version of bsizes[]
19f1be3500SJunchao Zhang     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
209d13fa56SJunchao Zhang     blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block
21f1be3500SJunchao Zhang   */
229d13fa56SJunchao Zhang   PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual;
239d13fa56SJunchao Zhang   PetscScalarKokkosView  diag; // buffer to store diagonal blocks
249d13fa56SJunchao Zhang   PetscScalarKokkosView  work; // work buffer, with the same size as diag[]
259d13fa56SJunchao Zhang   PetscLogDouble         setupFlops;
26f1be3500SJunchao Zhang 
279d13fa56SJunchao Zhang   // clang-format off
289d13fa56SJunchao Zhang   // n:               size of the matrix
299d13fa56SJunchao Zhang   // nblocks:         number of blocks
309d13fa56SJunchao Zhang   // nsize:           sum bsizes[i]^2 for i=0..nblocks
319d13fa56SJunchao Zhang   // bsizes[nblocks]: sizes of blocks
329d13fa56SJunchao Zhang   PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) :
339d13fa56SJunchao Zhang     n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1),
349d13fa56SJunchao Zhang     bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n),
359d13fa56SJunchao Zhang     diag(NoInit("diag"), nsize), work(NoInit("work"), nsize)
36d71ae5a4SJacob Faibussowitsch   {
379d13fa56SJunchao Zhang     PetscCallVoid(BuildHelperArrays(bsizes));
38f1be3500SJunchao Zhang   }
399d13fa56SJunchao Zhang   // clang-format on
40f1be3500SJunchao Zhang 
41f1be3500SJunchao Zhang private:
429d13fa56SJunchao Zhang   PetscErrorCode BuildHelperArrays(const PetscInt *bsizes)
43d71ae5a4SJacob Faibussowitsch   {
44f1be3500SJunchao Zhang     PetscInt *bs_h     = bs_dual.view_host().data();
45f1be3500SJunchao Zhang     PetscInt *bs2_h    = bs2_dual.view_host().data();
469d13fa56SJunchao Zhang     PetscInt *blkMap_h = blkMap_dual.view_host().data();
47f1be3500SJunchao Zhang 
48f1be3500SJunchao Zhang     PetscFunctionBegin;
499d13fa56SJunchao Zhang     setupFlops = 0.0;
50f1be3500SJunchao Zhang     bs_h[0] = bs2_h[0] = 0;
51f1be3500SJunchao Zhang     for (PetscInt i = 0; i < nblocks; i++) {
529d13fa56SJunchao Zhang       PetscInt m   = bsizes[i];
539d13fa56SJunchao Zhang       bs_h[i + 1]  = bs_h[i] + m;
549d13fa56SJunchao Zhang       bs2_h[i + 1] = bs2_h[i] + m * m;
559d13fa56SJunchao Zhang       for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i;
569d13fa56SJunchao Zhang       // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse
579d13fa56SJunchao Zhang       setupFlops += 8.0 * m * m * m / 3;
58f1be3500SJunchao Zhang     }
599d13fa56SJunchao Zhang 
609d13fa56SJunchao Zhang     PetscCallCXX(bs_dual.modify_host());
619d13fa56SJunchao Zhang     PetscCallCXX(bs2_dual.modify_host());
629d13fa56SJunchao Zhang     PetscCallCXX(blkMap_dual.modify_host());
63*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(bs_dual, PetscGetKokkosExecutionSpace()));
64*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(bs2_dual, PetscGetKokkosExecutionSpace()));
65*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(blkMap_dual, PetscGetKokkosExecutionSpace()));
663ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
67f1be3500SJunchao Zhang   }
68f1be3500SJunchao Zhang };
69f1be3500SJunchao Zhang 
7069eda9daSJed Brown template <PetscBool transpose>
7169eda9daSJed Brown static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
72d71ae5a4SJacob Faibussowitsch {
73f1be3500SJunchao Zhang   PC_VPBJacobi              *jac   = (PC_VPBJacobi *)pc->data;
74f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos       *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
75f1be3500SJunchao Zhang   ConstPetscScalarKokkosView xv;
76f1be3500SJunchao Zhang   PetscScalarKokkosView      yv;
779d13fa56SJunchao Zhang   PetscScalarKokkosView      diag   = pckok->diag;
78f1be3500SJunchao Zhang   PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
79f1be3500SJunchao Zhang   PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
809d13fa56SJunchao Zhang   PetscIntKokkosView         blkMap = pckok->blkMap_dual.view_device();
816dc07963SJunchao Zhang   const char                *label  = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi";
82f1be3500SJunchao Zhang 
83f1be3500SJunchao Zhang   PetscFunctionBegin;
849a56b474SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
85f1be3500SJunchao Zhang   VecErrorIfNotKokkos(x);
86f1be3500SJunchao Zhang   VecErrorIfNotKokkos(y);
87f1be3500SJunchao Zhang   PetscCall(VecGetKokkosView(x, &xv));
88f1be3500SJunchao Zhang   PetscCall(VecGetKokkosViewWrite(y, &yv));
896dc07963SJunchao Zhang #if 0 // TODO: Why the TeamGemv version is 2x worse than the naive one?
906dc07963SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
91d326c3f1SJunchao Zhang     label, Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) {
926dc07963SJunchao Zhang       PetscInt           bid  = team.league_rank();    // block id
936dc07963SJunchao Zhang       PetscInt           n    = bs(bid + 1) - bs(bid); // size of this block
946dc07963SJunchao Zhang       const PetscScalar *bbuf = &diag(bs2(bid));
956dc07963SJunchao Zhang       const PetscScalar *xbuf = &xv(bs(bid));
966dc07963SJunchao Zhang       PetscScalar       *ybuf = &yv(bs(bid));
976dc07963SJunchao Zhang       const auto        &B    = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order
986dc07963SJunchao Zhang       const auto        &x1   = ConstPetscScalarKokkosView(xbuf, n);
996dc07963SJunchao Zhang       const auto        &y1   = PetscScalarKokkosView(ybuf, n);
1006dc07963SJunchao Zhang       if (transpose) {
1016dc07963SJunchao Zhang         KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1
1026dc07963SJunchao Zhang       } else {
1036dc07963SJunchao Zhang         KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1
1046dc07963SJunchao Zhang       }
1056dc07963SJunchao Zhang     }));
1066dc07963SJunchao Zhang #else
1079371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
108d326c3f1SJunchao Zhang     label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, pckok->n), KOKKOS_LAMBDA(PetscInt row) {
1096dc07963SJunchao Zhang       const PetscScalar *Bp, *xp;
110f1be3500SJunchao Zhang       PetscScalar       *yp;
111f1be3500SJunchao Zhang       PetscInt           i, j, k, m;
112f1be3500SJunchao Zhang 
1139d13fa56SJunchao Zhang       k  = blkMap(row);                             /* k-th block/matrix */
114f1be3500SJunchao Zhang       m  = bs(k + 1) - bs(k);                       /* block size of the k-th block */
115f1be3500SJunchao Zhang       i  = row - bs(k);                             /* i-th row of the block */
1166dc07963SJunchao Zhang       Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of i-th row/column */
117f1be3500SJunchao Zhang       xp = &xv(bs(k));
118f1be3500SJunchao Zhang       yp = &yv(bs(k));
119f1be3500SJunchao Zhang 
120f1be3500SJunchao Zhang       yp[i] = 0.0;
1219371c9d4SSatish Balay       for (j = 0; j < m; j++) {
1226dc07963SJunchao Zhang         yp[i] += Bp[0] * xp[j];
1236dc07963SJunchao Zhang         Bp += transpose ? 1 : m;
1249371c9d4SSatish Balay       }
125f1be3500SJunchao Zhang     }));
1266dc07963SJunchao Zhang #endif
127f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosView(x, &xv));
128f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
1299a56b474SJunchao Zhang   PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
1309a56b474SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
132f1be3500SJunchao Zhang }
133f1be3500SJunchao Zhang 
134d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
135d71ae5a4SJacob Faibussowitsch {
136f1be3500SJunchao Zhang   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
137f1be3500SJunchao Zhang 
138f1be3500SJunchao Zhang   PetscFunctionBegin;
139f1be3500SJunchao Zhang   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
140f1be3500SJunchao Zhang   PetscCall(PCDestroy_VPBJacobi(pc));
1413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
142f1be3500SJunchao Zhang }
143f1be3500SJunchao Zhang 
14442ce410bSJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc, Mat diagVPB)
145d71ae5a4SJacob Faibussowitsch {
146f1be3500SJunchao Zhang   PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
147f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
1489d13fa56SJunchao Zhang   PetscInt             i, nlocal, nblocks, nsize = 0;
149f1be3500SJunchao Zhang   const PetscInt      *bsizes;
1505994c5adSJunchao Zhang   PetscBool            ismpi;
1515994c5adSJunchao Zhang   Mat                  A;
152f1be3500SJunchao Zhang 
153f1be3500SJunchao Zhang   PetscFunctionBegin;
154f1be3500SJunchao Zhang   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
1559d13fa56SJunchao Zhang   PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
1569d13fa56SJunchao Zhang   PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");
157f1be3500SJunchao Zhang 
1589d13fa56SJunchao Zhang   if (!jac->diag) {
1591690c2aeSBarry Smith     PetscInt max_bs = -1, min_bs = PETSC_INT_MAX;
1609d13fa56SJunchao Zhang     for (i = 0; i < nblocks; i++) {
1619d13fa56SJunchao Zhang       min_bs = PetscMin(min_bs, bsizes[i]);
1629d13fa56SJunchao Zhang       max_bs = PetscMax(max_bs, bsizes[i]);
1639d13fa56SJunchao Zhang       nsize += bsizes[i] * bsizes[i];
1649d13fa56SJunchao Zhang     }
1659d13fa56SJunchao Zhang     jac->nblocks = nblocks;
1669d13fa56SJunchao Zhang     jac->min_bs  = min_bs;
1679d13fa56SJunchao Zhang     jac->max_bs  = max_bs;
1689d13fa56SJunchao Zhang   }
1699d13fa56SJunchao Zhang 
1709d13fa56SJunchao Zhang   // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway
1719d13fa56SJunchao Zhang   if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
172f1be3500SJunchao Zhang     PetscCallCXX(delete pckok);
173f1be3500SJunchao Zhang     pckok = nullptr;
174f1be3500SJunchao Zhang   }
175f1be3500SJunchao Zhang 
1769d13fa56SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1779d13fa56SJunchao Zhang   if (!pckok) {
1789d13fa56SJunchao Zhang     PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes));
1799d13fa56SJunchao Zhang     jac->spptr = pckok;
180f1be3500SJunchao Zhang   }
181f1be3500SJunchao Zhang 
1829d13fa56SJunchao Zhang   // Extract diagonal blocks from the matrix and compute their inverse
1839d13fa56SJunchao Zhang   const auto &bs     = pckok->bs_dual.view_device();
1849d13fa56SJunchao Zhang   const auto &bs2    = pckok->bs2_dual.view_device();
1859d13fa56SJunchao Zhang   const auto &blkMap = pckok->blkMap_dual.view_device();
18642ce410bSJunchao Zhang   if (diagVPB) { // If caller provided a matrix made of the diagonal blocks, use it
187077598ddSJames Wright     PetscCall(PetscObjectBaseTypeCompare((PetscObject)diagVPB, MATMPIAIJ, &ismpi));
188077598ddSJames Wright     A = ismpi ? static_cast<Mat_MPIAIJ *>(diagVPB->data)->A : diagVPB;
18942ce410bSJunchao Zhang   } else {
190077598ddSJames Wright     PetscCall(PetscObjectBaseTypeCompare((PetscObject)pc->pmat, MATMPIAIJ, &ismpi));
191f4f49eeaSPierre Jolivet     A = ismpi ? static_cast<Mat_MPIAIJ *>(pc->pmat->data)->A : pc->pmat;
19242ce410bSJunchao Zhang   }
1935994c5adSJunchao Zhang   PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(A, bs, bs2, blkMap, pckok->work, pckok->diag));
19469eda9daSJed Brown   pc->ops->apply          = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>;
19569eda9daSJed Brown   pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>;
196f1be3500SJunchao Zhang   pc->ops->destroy        = PCDestroy_VPBJacobi_Kokkos;
1979d13fa56SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1989d13fa56SJunchao Zhang   PetscCall(PetscLogGpuFlops(pckok->setupFlops));
1993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
200f1be3500SJunchao Zhang }
201