xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision 9d13fa56c5c6523e02c36edc0e4e22bf2d0334a8)
1f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp>
2f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3f1be3500SJunchao Zhang #include <petscdevice.h>
4f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5*9d13fa56SJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> // for MatInvertVariableBlockDiagonal_SeqAIJKokkos
6f1be3500SJunchao Zhang 
7f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
8f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos {
9f1be3500SJunchao Zhang   /* Cache the old sizes to check if we need realloc */
10f1be3500SJunchao Zhang   PetscInt n;       /* number of rows of the local matrix */
11f1be3500SJunchao Zhang   PetscInt nblocks; /* number of point blocks */
12*9d13fa56SJunchao Zhang   PetscInt nsize;   /* sum of sizes (elements) of the point blocks */
13f1be3500SJunchao Zhang 
14f1be3500SJunchao Zhang   /* Helper arrays that are pre-computed on host and then copied to device.
15f1be3500SJunchao Zhang     bs:     [nblocks+1], "csr" version of bsizes[]
16f1be3500SJunchao Zhang     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
17*9d13fa56SJunchao Zhang     blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block
18f1be3500SJunchao Zhang   */
19*9d13fa56SJunchao Zhang   PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual;
20*9d13fa56SJunchao Zhang   PetscScalarKokkosView  diag; // buffer to store diagonal blocks
21*9d13fa56SJunchao Zhang   PetscScalarKokkosView  work; // work buffer, with the same size as diag[]
22*9d13fa56SJunchao Zhang   PetscLogDouble         setupFlops;
23f1be3500SJunchao Zhang 
24*9d13fa56SJunchao Zhang   // clang-format off
25*9d13fa56SJunchao Zhang   // n:               size of the matrix
26*9d13fa56SJunchao Zhang   // nblocks:         number of blocks
27*9d13fa56SJunchao Zhang   // nsize:           sum bsizes[i]^2 for i=0..nblocks
28*9d13fa56SJunchao Zhang   // bsizes[nblocks]: sizes of blocks
29*9d13fa56SJunchao Zhang   PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) :
30*9d13fa56SJunchao Zhang     n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1),
31*9d13fa56SJunchao Zhang     bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n),
32*9d13fa56SJunchao Zhang     diag(NoInit("diag"), nsize), work(NoInit("work"), nsize)
33d71ae5a4SJacob Faibussowitsch   {
34*9d13fa56SJunchao Zhang     PetscCallVoid(BuildHelperArrays(bsizes));
35f1be3500SJunchao Zhang   }
36*9d13fa56SJunchao Zhang   // clang-format on
37f1be3500SJunchao Zhang 
38f1be3500SJunchao Zhang private:
39*9d13fa56SJunchao Zhang   PetscErrorCode BuildHelperArrays(const PetscInt *bsizes)
40d71ae5a4SJacob Faibussowitsch   {
41f1be3500SJunchao Zhang     PetscInt *bs_h     = bs_dual.view_host().data();
42f1be3500SJunchao Zhang     PetscInt *bs2_h    = bs2_dual.view_host().data();
43*9d13fa56SJunchao Zhang     PetscInt *blkMap_h = blkMap_dual.view_host().data();
44f1be3500SJunchao Zhang 
45f1be3500SJunchao Zhang     PetscFunctionBegin;
46*9d13fa56SJunchao Zhang     setupFlops = 0.0;
47f1be3500SJunchao Zhang     bs_h[0] = bs2_h[0] = 0;
48f1be3500SJunchao Zhang     for (PetscInt i = 0; i < nblocks; i++) {
49*9d13fa56SJunchao Zhang       PetscInt m   = bsizes[i];
50*9d13fa56SJunchao Zhang       bs_h[i + 1]  = bs_h[i] + m;
51*9d13fa56SJunchao Zhang       bs2_h[i + 1] = bs2_h[i] + m * m;
52*9d13fa56SJunchao Zhang       for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i;
53*9d13fa56SJunchao Zhang       // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse
54*9d13fa56SJunchao Zhang       setupFlops += 8.0 * m * m * m / 3;
55f1be3500SJunchao Zhang     }
56*9d13fa56SJunchao Zhang 
57*9d13fa56SJunchao Zhang     PetscCallCXX(bs_dual.modify_host());
58*9d13fa56SJunchao Zhang     PetscCallCXX(bs2_dual.modify_host());
59*9d13fa56SJunchao Zhang     PetscCallCXX(blkMap_dual.modify_host());
60*9d13fa56SJunchao Zhang     PetscCallCXX(bs_dual.sync_device());
61*9d13fa56SJunchao Zhang     PetscCallCXX(bs2_dual.sync_device());
62*9d13fa56SJunchao Zhang     PetscCallCXX(blkMap_dual.sync_device());
63*9d13fa56SJunchao Zhang     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * (nblocks + 1) + n)));
643ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
65f1be3500SJunchao Zhang   }
66f1be3500SJunchao Zhang };
67f1be3500SJunchao Zhang 
6869eda9daSJed Brown template <PetscBool transpose>
6969eda9daSJed Brown static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
70d71ae5a4SJacob Faibussowitsch {
71f1be3500SJunchao Zhang   PC_VPBJacobi              *jac   = (PC_VPBJacobi *)pc->data;
72f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos       *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
73f1be3500SJunchao Zhang   ConstPetscScalarKokkosView xv;
74f1be3500SJunchao Zhang   PetscScalarKokkosView      yv;
75*9d13fa56SJunchao Zhang   PetscScalarKokkosView      diag   = pckok->diag;
76f1be3500SJunchao Zhang   PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
77f1be3500SJunchao Zhang   PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
78*9d13fa56SJunchao Zhang   PetscIntKokkosView         blkMap = pckok->blkMap_dual.view_device();
7969eda9daSJed Brown   const char                *label  = transpose ? "PCApplyTranspose_VPBJacobi_Kokkos" : "PCApply_VPBJacobi_Kokkos";
80f1be3500SJunchao Zhang 
81f1be3500SJunchao Zhang   PetscFunctionBegin;
829a56b474SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
83f1be3500SJunchao Zhang   VecErrorIfNotKokkos(x);
84f1be3500SJunchao Zhang   VecErrorIfNotKokkos(y);
85f1be3500SJunchao Zhang   PetscCall(VecGetKokkosView(x, &xv));
86f1be3500SJunchao Zhang   PetscCall(VecGetKokkosViewWrite(y, &yv));
879371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
8869eda9daSJed Brown     label, pckok->n, KOKKOS_LAMBDA(PetscInt row) {
89f1be3500SJunchao Zhang       const PetscScalar *Ap, *xp;
90f1be3500SJunchao Zhang       PetscScalar       *yp;
91f1be3500SJunchao Zhang       PetscInt           i, j, k, m;
92f1be3500SJunchao Zhang 
93*9d13fa56SJunchao Zhang       k  = blkMap(row);                             /* k-th block/matrix */
94f1be3500SJunchao Zhang       m  = bs(k + 1) - bs(k);                       /* block size of the k-th block */
95f1be3500SJunchao Zhang       i  = row - bs(k);                             /* i-th row of the block */
9669eda9daSJed Brown       Ap = &diag(bs2(k) + i * (transpose ? m : 1)); /* Ap points to the first entry of i-th row/column */
97f1be3500SJunchao Zhang       xp = &xv(bs(k));
98f1be3500SJunchao Zhang       yp = &yv(bs(k));
99f1be3500SJunchao Zhang 
100f1be3500SJunchao Zhang       yp[i] = 0.0;
1019371c9d4SSatish Balay       for (j = 0; j < m; j++) {
1029371c9d4SSatish Balay         yp[i] += Ap[0] * xp[j];
10369eda9daSJed Brown         Ap += transpose ? 1 : m;
1049371c9d4SSatish Balay       }
105f1be3500SJunchao Zhang     }));
106f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosView(x, &xv));
107f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
1089a56b474SJunchao Zhang   PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
1099a56b474SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1103ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
111f1be3500SJunchao Zhang }
112f1be3500SJunchao Zhang 
113d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
114d71ae5a4SJacob Faibussowitsch {
115f1be3500SJunchao Zhang   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
116f1be3500SJunchao Zhang 
117f1be3500SJunchao Zhang   PetscFunctionBegin;
118f1be3500SJunchao Zhang   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
119f1be3500SJunchao Zhang   PetscCall(PCDestroy_VPBJacobi(pc));
1203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
121f1be3500SJunchao Zhang }
122f1be3500SJunchao Zhang 
123d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
124d71ae5a4SJacob Faibussowitsch {
125f1be3500SJunchao Zhang   PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
126f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
127*9d13fa56SJunchao Zhang   PetscInt             i, nlocal, nblocks, nsize = 0;
128f1be3500SJunchao Zhang   const PetscInt      *bsizes;
129f1be3500SJunchao Zhang 
130f1be3500SJunchao Zhang   PetscFunctionBegin;
131f1be3500SJunchao Zhang   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
132*9d13fa56SJunchao Zhang   PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
133*9d13fa56SJunchao Zhang   PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");
134f1be3500SJunchao Zhang 
135*9d13fa56SJunchao Zhang   if (!jac->diag) {
136*9d13fa56SJunchao Zhang     PetscInt max_bs = -1, min_bs = PETSC_MAX_INT;
137*9d13fa56SJunchao Zhang     for (i = 0; i < nblocks; i++) {
138*9d13fa56SJunchao Zhang       min_bs = PetscMin(min_bs, bsizes[i]);
139*9d13fa56SJunchao Zhang       max_bs = PetscMax(max_bs, bsizes[i]);
140*9d13fa56SJunchao Zhang       nsize += bsizes[i] * bsizes[i];
141*9d13fa56SJunchao Zhang     }
142*9d13fa56SJunchao Zhang     jac->nblocks = nblocks;
143*9d13fa56SJunchao Zhang     jac->min_bs  = min_bs;
144*9d13fa56SJunchao Zhang     jac->max_bs  = max_bs;
145*9d13fa56SJunchao Zhang   }
146*9d13fa56SJunchao Zhang 
147*9d13fa56SJunchao Zhang   // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway
148*9d13fa56SJunchao Zhang   if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
149f1be3500SJunchao Zhang     PetscCallCXX(delete pckok);
150f1be3500SJunchao Zhang     pckok = nullptr;
151f1be3500SJunchao Zhang   }
152f1be3500SJunchao Zhang 
153*9d13fa56SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
154*9d13fa56SJunchao Zhang   if (!pckok) {
155*9d13fa56SJunchao Zhang     PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes));
156*9d13fa56SJunchao Zhang     jac->spptr = pckok;
157f1be3500SJunchao Zhang   }
158f1be3500SJunchao Zhang 
159*9d13fa56SJunchao Zhang   // Extract diagonal blocks from the matrix and compute their inverse
160*9d13fa56SJunchao Zhang   const auto &bs     = pckok->bs_dual.view_device();
161*9d13fa56SJunchao Zhang   const auto &bs2    = pckok->bs2_dual.view_device();
162*9d13fa56SJunchao Zhang   const auto &blkMap = pckok->blkMap_dual.view_device();
163*9d13fa56SJunchao Zhang   PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(pc->pmat, bs, bs2, blkMap, pckok->work, pckok->diag));
16469eda9daSJed Brown   pc->ops->apply          = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>;
16569eda9daSJed Brown   pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>;
166f1be3500SJunchao Zhang   pc->ops->destroy        = PCDestroy_VPBJacobi_Kokkos;
167*9d13fa56SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
168*9d13fa56SJunchao Zhang   PetscCall(PetscLogGpuFlops(pckok->setupFlops));
1693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170f1be3500SJunchao Zhang }
171