xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 72a35d6d63ea98b60d8ef4c5aaefda5ff933c0fe)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petscsf.h>
4 #include <petsc/private/sfimpl.h>
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7 #include <KokkosSparse_spadd.hpp>
8 
9 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
10 {
11   Mat_SeqAIJKokkos *aijkok;
12   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
13 
14   PetscFunctionBegin;
15   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
16   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
17      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
18    */
19   if (mode == MAT_FINAL_ASSEMBLY) {
20     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
21     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
22     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
23   }
24   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
25   if (aijkok && aijkok->device_mat_d.data()) {
26     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
27   }
28 
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33 {
34   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
35 
36   PetscFunctionBegin;
37   PetscCall(PetscLayoutSetUp(mat->rmap));
38   PetscCall(PetscLayoutSetUp(mat->cmap));
39 #if defined(PETSC_USE_DEBUG)
40   if (d_nnz) {
41     PetscInt i;
42     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
43   }
44   if (o_nnz) {
45     PetscInt i;
46     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
47   }
48 #endif
49 #if defined(PETSC_USE_CTABLE)
50   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
51 #else
52   PetscCall(PetscFree(mpiaij->colmap));
53 #endif
54   PetscCall(PetscFree(mpiaij->garray));
55   PetscCall(VecDestroy(&mpiaij->lvec));
56   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
57   /* Because the B will have been resized we simply destroy it and create a new one each time */
58   PetscCall(MatDestroy(&mpiaij->B));
59 
60   if (!mpiaij->A) {
61     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
62     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
63   }
64   if (!mpiaij->B) {
65     PetscMPIInt size;
66     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
67     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
68     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
69   }
70   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
71   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
72   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
73   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
74   mat->preallocated = PETSC_TRUE;
75   PetscFunctionReturn(PETSC_SUCCESS);
76 }
77 
78 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
79 {
80   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
81   PetscInt    nt;
82 
83   PetscFunctionBegin;
84   PetscCall(VecGetLocalSize(xx, &nt));
85   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
86   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
87   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
88   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
89   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
90   PetscFunctionReturn(PETSC_SUCCESS);
91 }
92 
93 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
94 {
95   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
96   PetscInt    nt;
97 
98   PetscFunctionBegin;
99   PetscCall(VecGetLocalSize(xx, &nt));
100   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
101   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
102   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
103   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
104   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
105   PetscFunctionReturn(PETSC_SUCCESS);
106 }
107 
108 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
109 {
110   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
111   PetscInt    nt;
112 
113   PetscFunctionBegin;
114   PetscCall(VecGetLocalSize(xx, &nt));
115   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
116   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
117   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
118   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
120   PetscFunctionReturn(PETSC_SUCCESS);
121 }
122 
123 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
124    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
125    C still uses local column ids. Their corresponding global column ids are returned in glob.
126 */
127 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
128 {
129   Mat             Ad, Ao;
130   const PetscInt *cmap;
131 
132   PetscFunctionBegin;
133   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
134   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
135   if (glob) {
136     PetscInt cst, i, dn, on, *gidx;
137     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
138     PetscCall(MatGetLocalSize(Ao, NULL, &on));
139     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
140     PetscCall(PetscMalloc1(dn + on, &gidx));
141     for (i = 0; i < dn; i++) gidx[i] = cst + i;
142     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
143     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
144   }
145   PetscFunctionReturn(PETSC_SUCCESS);
146 }
147 
148 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
149 struct MatMatStruct {
150   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
151   PetscSF             sf;      /* SF to send/recv matrix entries */
152   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
153   Mat                 C1, C2, B_local;
154   KokkosCsrMatrix     C1_global, C2_global, C_global;
155   KernelHandle        kh;
156   MatMatStruct() noexcept : sf(nullptr), C1(nullptr), C2(nullptr), B_local(nullptr) { }
157 
158   ~MatMatStruct()
159   {
160     PetscFunctionBegin;
161     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C1));
162     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C2));
163     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_local));
164     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
165     kh.destroy_spadd_handle();
166     PetscFunctionReturnVoid();
167   }
168 };
169 
170 struct MatMatStruct_AB : public MatMatStruct {
171   MatColIdxKokkosView rows{};
172   MatRowMapKokkosView rowoffset{};
173   Mat                 B_other{}, C_petsc{}; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
174   MatColIdxKokkosView B_NzDiagLeft;         // Number of nonzeros on the left of B's diagonal block; Used to recover the unsplit B (i.e., local mat)
175 
176   ~MatMatStruct_AB() noexcept
177   {
178     PetscFunctionBegin;
179     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_other));
180     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C_petsc));
181     PetscFunctionReturnVoid();
182   }
183 };
184 
185 struct MatMatStruct_AtB : public MatMatStruct {
186   MatRowMapKokkosView srcrowoffset, dstrowoffset;
187 };
188 
189 struct MatProductData_MPIAIJKokkos {
190   MatMatStruct_AB  *mmAB     = nullptr;
191   MatMatStruct_AtB *mmAtB    = nullptr;
192   PetscBool         reusesym = PETSC_FALSE;
193 
194   ~MatProductData_MPIAIJKokkos()
195   {
196     delete mmAB;
197     delete mmAtB;
198   }
199 };
200 
201 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202 {
203   PetscFunctionBegin;
204   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
205   PetscFunctionReturn(PETSC_SUCCESS);
206 }
207 
208 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
209 
210    Input Parameters:
211 +  A       - the MATSEQAIJKOKKOS matrix
212 .  N       - new column size for the returned Kokkos matrix
213 -  l2g     - a map that maps old col ids to new col ids
214 
215    Output Parameters:
216 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
217  */
218 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
219 {
220   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
221   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
222 
223   PetscFunctionBegin;
224   PetscCallCXX(Kokkos::parallel_for(
225     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
226   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
227   PetscFunctionReturn(PETSC_SUCCESS);
228 }
229 
230 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
231    It is similar to MatCreateMPIAIJWithSplitArrays.
232 
233   Input Parameters:
234 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
235 .  A     - the diag matrix using local col ids
236 -  B     - the offdiag matrix using global col ids
237 
238   Output Parameters:
239 .  mat   - the updated MATMPIAIJKOKKOS matrix
240 */
241 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
242 {
243   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
244   PetscInt          m, n, M, N, Am, An, Bm, Bn;
245   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
246 
247   PetscFunctionBegin;
248   PetscCall(MatGetSize(mat, &M, &N));
249   PetscCall(MatGetLocalSize(mat, &m, &n));
250   PetscCall(MatGetLocalSize(A, &Am, &An));
251   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
252 
253   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
254   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
255   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
256   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
257   mpiaij->A = A;
258   mpiaij->B = B;
259 
260   mat->preallocated     = PETSC_TRUE;
261   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
262 
263   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
264   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
265   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
266     also gets mpiaij->B compacted, with its col ids and size reduced
267   */
268   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
269   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
270   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
271 
272   /* Update bkok with new local col ids (stored on host) and size */
273   bkok->j_dual.modify_host();
274   bkok->j_dual.sync_device();
275   bkok->SetColSize(mpiaij->B->cmap->n);
276   PetscFunctionReturn(PETSC_SUCCESS);
277 }
278 
279 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
280 
281    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
282    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
283    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
284    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
285 
286    Collective
287 
288    Input Parameters:
289 +   B       - the SEQAIJKOKKOS matrix, using local col ids
290 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
291 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
292 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
293 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
294 
295    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
296 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
297 .   abuf      - buffer for sending matrix values
298 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
299                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
300 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
301 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
302 */
303 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
304 {
305   Mat_SeqAIJKokkos *bkok, *ckok;
306 
307   PetscFunctionBegin;
308   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
309   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
310 
311   if (reuse == MAT_REUSE_MATRIX) {
312     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
313 
314     const auto &Ba = bkok->a_dual.view_device();
315     const auto &Bi = bkok->i_dual.view_device();
316     const auto &Ca = ckok->a_dual.view_device();
317 
318     /* Copy Ba to abuf */
319     Kokkos::parallel_for(
320       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
321         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
322         PetscInt r    = rows(i);
323         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
324         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
325       });
326 
327     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
328     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
329     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
330     ckok->a_dual.modify_device();
331   } else if (reuse == MAT_INITIAL_MATRIX) {
332     MPI_Comm    comm;
333     PetscMPIInt tag;
334     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
335 
336     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
337     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
338     Cm = nleaves; /* row size of C */
339     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
340 
341     /* Get row lens (nz) of B's rows for later fast query */
342     PetscInt       *Browlens;
343     const PetscInt *tmp = bkok->i_host_data();
344     PetscCall(PetscMalloc1(nroots, &Browlens));
345     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
346 
347     /* By ownerSF, each proc gets lens of rows of C */
348     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
349     Ci_h    = Ci.view_host().data();
350     Ci_h[0] = 0;
351     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
352     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
353     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
354     Cnnz = Ci_h[Cm];
355     Ci.modify_host();
356     Ci.sync_device();
357 
358     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
359     MatColIdxKokkosDualView Cj("j", Cnnz);
360     MatScalarKokkosDualView Ca("a", Cnnz);
361 
362     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
363     const PetscMPIInt *iranks, *ranks;
364     const PetscInt    *ioffset, *irootloc, *roffset;
365     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
366     MPI_Request       *reqs;
367 
368     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
369     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
370 
371     /* figure out offsets at the send buffer, to build the SF
372       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
373       rowptr[] - stores offsets for data of each row in abuf
374 
375       rdisp[]  - to receive sdisp[]
376     */
377     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
378     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
379     rowptr = rowptr_h.data();
380 
381     sdisp[0]  = 0;
382     rowptr[0] = 0;
383     for (i = 0; i < niranks; i++) { /* for each receiver */
384       PetscInt len, nz = 0;
385       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
386         len           = Browlens[irootloc[j]];
387         rowptr[j + 1] = rowptr[j] + len;
388         nz += len;
389       }
390       sdisp[i + 1] = sdisp[i] + nz;
391     }
392     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
393     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
394     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
395     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
396 
397     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
398     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
399     PetscSFNode *iremote;
400     PetscCall(PetscMalloc1(nleaves2, &iremote));
401     for (i = 0; i < nranks; i++) { /* for each sender */
402       k = 0;
403       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
404         iremote[j].rank  = ranks[i];
405         iremote[j].index = rdisp[i] + k;
406         k++;
407       }
408     }
409     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
410     PetscCall(PetscSFCreate(comm, &bcastSF));
411     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
412 
413     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
414       from local to global. Then use bcastSF to fill Ca, Cj.
415     */
416     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
417     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
418     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
419 
420     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
421 
422     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
423     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
424 
425     const auto &Ba = bkok->a_dual.view_device();
426     const auto &Bi = bkok->i_dual.view_device();
427     const auto &Bj = bkok->j_dual.view_device();
428 
429     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
430     Kokkos::parallel_for(
431       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
432         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
433         PetscInt r    = rows(i);
434         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
435         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
436           abuf(base + k) = Ba(Bi(r) + k);
437           jbuf(base + k) = l2g(Bj(Bi(r) + k));
438         });
439       });
440 
441     /* Send abuf & jbuf to fill Ca, Cj */
442     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
443     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
444     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
445     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
446     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
447     Cj.sync_host();
448     Ca.modify_device();
449 
450     /* Construct C with Ca, Ci, Cj */
451     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
452     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
453     PetscCall(PetscFree3(sdisp, rdisp, reqs));
454     PetscCall(PetscFree(Browlens));
455   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
456   PetscFunctionReturn(PETSC_SUCCESS);
457 }
458 
459 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
460 
461   It is the reverse of MatSeqAIJKokkosBcast in some sense.
462 
463   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
464   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
465   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
466 
467   Input Parameters:
468 +  A        - the SEQAIJKOKKOS matrix to be reduced
469 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
470 .  local    - true if A uses local col ids; false if A is already in global col ids.
471 .  N        - if local, N is A's global col size
472 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
473 -  ownerSF  - the SF specifies ownership (root) of rows in A
474 
475   Output Parameters:
476 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
477 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
478 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
479 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
480                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
481 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
482 
483    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
484  */
485 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
486 {
487   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
488   const PetscInt   *Ai;
489   Mat_SeqAIJKokkos *akok;
490 
491   PetscFunctionBegin;
492   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
493   PetscCall(MatGetSize(A, &Am, &An));
494   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
495   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
496   Annz = Ai[Am];
497 
498   if (reuse == MAT_REUSE_MATRIX) {
499     /* Send Aa to abuf */
500     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
501     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
502 
503     /* Copy abuf to Ca */
504     const MatScalarKokkosView &Ca = C.values;
505     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
506     Kokkos::parallel_for(
507       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
508         PetscInt i   = t.league_rank();
509         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
510         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
511         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
512       });
513   } else if (reuse == MAT_INITIAL_MATRIX) {
514     MPI_Comm     comm;
515     MPI_Request *reqs;
516     PetscMPIInt  tag;
517     PetscInt     Cm;
518 
519     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
520     PetscCall(PetscCommGetNewTag(comm, &tag));
521 
522     PetscInt           niranks, nranks, nroots, nleaves;
523     const PetscMPIInt *iranks, *ranks;
524     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
525     PetscCall(PetscSFSetUp(ownerSF));
526     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
527     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
528     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
529     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
530     Cm    = nroots;
531     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
532 
533     /* Tell owners how long each row I will send */
534     PetscInt               *srowlens;                              /* send buf of row lens */
535     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
536     PetscInt               *rrowlens = rrowlens_h.data();
537 
538     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
539     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
540     rrowlens[0] = 0;
541     rrowlens++; /* shift the pointer to make the following expression more readable */
542     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
543     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
544     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
545 
546     /* Owner builds Ci on host by histogramming rrowlens[] */
547     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
548     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
549     MatRowMapType *Ci_ptr = Ci_h.data();
550 
551     for (i = 0; i < nrows; i++) {
552       r = rows[i]; /* local row id of i-th received row */
553 #if defined(PETSC_USE_DEBUG)
554       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
555 #endif
556       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
557     }
558     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
559     Cnnz = Ci_ptr[Cm];
560 
561     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
562     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
563     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
564     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
565 
566     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
567     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
568       r                    = rows[i];                   /* row id in C */
569       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
570       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
571     }
572     PetscCall(PetscFree(currowlens));
573 
574     rrowlens--;
575     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
576     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
577     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
578 
579     /* Build the reduceSF, which performs buffer to buffer send/recv */
580     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
581     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
582     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
583     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
584     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
585     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
586 
587     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
588     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
589     PetscSFNode *iremote;
590     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
591     for (i = 0; i < nranks; i++) {
592       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
593       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
594       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
595       for (PetscInt k = 0; k < nz; k++) {
596         iremote[leafbase + k].rank  = ranks[i];
597         iremote[leafbase + k].index = rootbase + k;
598       }
599     }
600     PetscCall(PetscSFCreate(comm, &reduceSF));
601     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
602     PetscCall(PetscFree2(sdisp, rdisp));
603 
604     /* Reduce Aa, Ajg to abuf and jbuf */
605 
606     /* If A uses local col ids, convert them to global ones before sending */
607     MatColIdxKokkosView Ajg;
608     if (local) {
609       Ajg                           = MatColIdxKokkosView("j", Annz);
610       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
611       Kokkos::parallel_for(
612         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
613     } else {
614       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
615     }
616 
617     MatColIdxKokkosView jbuf("jbuf", Cnnz);
618     abuf = MatScalarKokkosView("abuf", Cnnz);
619     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
620     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
621     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
622     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
623 
624     /* Copy data from abuf, jbuf to Ca, Cj */
625     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
626     MatColIdxKokkosView Cj("j", Cnnz);
627     MatScalarKokkosView Ca("a", Cnnz);
628 
629     Kokkos::parallel_for(
630       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
631         PetscInt i   = t.league_rank();
632         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
633         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
634         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
635           Ca(dst + k) = abuf(src + k);
636           Cj(dst + k) = jbuf(src + k);
637         });
638       });
639 
640     /* Build C with Ca, Ci, Cj */
641     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
642     PetscCall(PetscFree2(srowlens, reqs));
643   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
644   PetscFunctionReturn(PETSC_SUCCESS);
645 }
646 
647 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
648 
649   Input Parameters:
650 +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
651 .  reuse    - indicate whether the matrix has called this function before
652 .  csrmat   - the KokkosCsrMatrix, of size m,N
653 -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
654               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
655               entry is 5, then Cdstart[i] = 3.
656 
657   Output Parameters:
658 +  C        - the updated `MATMPIAIJKOKKOS` matrix
659 -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
660 
661   Note:
662    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
663 
664 .seealso: [](chapter_matrices), `Mat`, `MATMPIAIJKOKKOS`
665  */
666 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
667 {
668   const MatScalarKokkosView      &Ca = csrmat.values;
669   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
670   PetscInt                        m, n, N;
671 
672   PetscFunctionBegin;
673   PetscCall(MatGetLocalSize(C, &m, &n));
674   PetscCall(MatGetSize(C, NULL, &N));
675 
676   if (reuse == MAT_REUSE_MATRIX) {
677     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
678     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
679     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
680     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
681     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
682 
683     /* Fill 'a' of Cd and Co on device */
684     Kokkos::parallel_for(
685       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
686         PetscInt i       = t.league_rank();     /* row i */
687         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
688         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
689         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
690         PetscInt cdend   = cdstart + cdlen;
691         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
692         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
693           if (k < cdstart) { /* k in [0, cdstart) */
694             Coa(Coi(i) + k) = Ca(Ci(i) + k);
695           } else if (k < cdend) { /* k in [cdstart, cdend) */
696             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
697           } else { /* k in [cdend, clen) */
698             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
699           }
700         });
701       });
702 
703     akok->a_dual.modify_device();
704     bkok->a_dual.modify_device();
705   } else if (reuse == MAT_INITIAL_MATRIX) {
706     Mat                        Cd, Co;
707     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
708     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
709     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
710     PetscInt                   cstart, cend;
711 
712     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
713        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
714        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
715        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
716        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
717      */
718     Cdstart = MatRowMapKokkosView("Cdstart", m);
719     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
720 
721     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
722       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
723      */
724     Kokkos::parallel_for(
725       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
726         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
727                                                    PetscInt i = t.league_rank(); /* row i */
728                                                    PetscInt j, first, count, step;
729 
730                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
731                                                      Cdi(0) = 0;
732                                                      Coi(0) = 0;
733                                                    }
734 
735                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
736           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
737         */
738                                                    count = Ci(i + 1) - Ci(i);
739                                                    first = Ci(i);
740                                                    while (count > 0) {
741                                                      j    = first;
742                                                      step = count / 2;
743                                                      j += step;
744                                                      if (Cj(j) < cstart) {
745                                                        first = ++j;
746                                                        count -= step + 1;
747                                                      } else count = step;
748                                                    }
749                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
750 
751                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
752                                                    count = Ci(i + 1) - first;
753                                                    while (count > 0) {
754                                                      j    = first;
755                                                      step = count / 2;
756                                                      j += step;
757                                                      if (Cj(j) < cend) {
758                                                        first = ++j;
759                                                        count -= step + 1;
760                                                      } else count = step;
761                                                    }
762                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
763                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
764         });
765       });
766 
767     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
768     Kokkos::parallel_scan(
769       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
770         update += Cdi(i);
771         if (final) Cdi(i) = update;
772       });
773     Kokkos::parallel_scan(
774       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
775         update += Coi(i);
776         if (final) Coi(i) = update;
777       });
778 
779     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
780        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
781     */
782     Cdi_dual.modify_device();
783     Coi_dual.modify_device();
784     Cdi_dual.sync_host();
785     Coi_dual.sync_host();
786     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
787     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
788 
789     /* With nnz, allocate a, j for Cd and Co */
790     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
791     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
792 
793     /* Fill a, j of Cd and Co on device */
794     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
795     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
796 
797     Kokkos::parallel_for(
798       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
799         PetscInt i       = t.league_rank();     /* row i */
800         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
801         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
802         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
803         PetscInt cdend   = cdstart + cdlen;
804         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
805         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
806           if (k < cdstart) { /* k in [0, cdstart) */
807             Coa(Coi(i) + k) = Ca(Ci(i) + k);
808             Coj(Coi(i) + k) = Cj(Ci(i) + k);
809           } else if (k < cdend) { /* k in [cdstart, cdend) */
810             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
811             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
812           } else {                                                /* k in [cdend, clen) */
813             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
814             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
815           }
816         });
817       });
818 
819     Cdj_dual.modify_device();
820     Cda_dual.modify_device();
821     Coj_dual.modify_device();
822     Coa_dual.modify_device();
823     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
824     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
825     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
826     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
827     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
828     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
829   }
830   PetscFunctionReturn(PETSC_SUCCESS);
831 }
832 
/* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.

  It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.

  Input Parameter:
.  C   - the SEQAIJKOKKOS matrix, using global col ids

  Output Parameters:
+  C   - the same matrix, with its col ids compacted and its col size updated accordingly
-  l2g - the local to global mapping of C's cols after compaction, of length C's new col size

  Note:
  the input matrix's col ids and col size will be changed.
*/
851 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
852 {
853   Mat_SeqAIJKokkos      *ckok;
854   ISLocalToGlobalMapping l2gmap;
855   const PetscInt        *garray;
856   PetscInt               sz;
857 
858   PetscFunctionBegin;
859   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
860   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
861   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
862   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
863   ckok->j_dual.sync_device();
864   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
865 
866   /* Build l2g -- the local to global mapping of C's cols */
867   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
868   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
869   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
870 
871   ConstMatColIdxKokkosViewHost tmp(garray, sz);
872   l2g = MatColIdxKokkosView("l2g", sz);
873   Kokkos::deep_copy(l2g, tmp);
874 
875   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
876   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
877   PetscFunctionReturn(PETSC_SUCCESS);
878 }
879 
880 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 99)
881 static PetscErrorCode MatMPIAIJGetLocalMat_MPIAIJKokkos(Mat mat, MatReuse reuse, MatMatStruct_AB *mm, Mat *C)
882 {
883   Mat                 A, B;
884   const PetscInt     *garray;
885   Mat_SeqAIJ         *aseq, *bseq;
886   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
887   MatScalarKokkosView aa, ba, ca;
888   MatRowMapKokkosView ai, bi, ci;
889   MatColIdxKokkosView aj, bj, cj;
890   PetscInt            m, nnz;
891 
892   PetscFunctionBegin;
893   PetscCall(MatMPIAIJGetSeqAIJ(mat, &A, &B, &garray));
894   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
895   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
896   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
897   PetscCall(MatSeqAIJKokkosSyncDevice(A));
898   PetscCall(MatSeqAIJKokkosSyncDevice(B));
899   aseq = static_cast<Mat_SeqAIJ *>(A->data);
900   bseq = static_cast<Mat_SeqAIJ *>(B->data);
901   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
902   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
903   aa   = akok->a_dual.view_device();
904   ai   = akok->i_dual.view_device();
905   ba   = bkok->a_dual.view_device();
906   bi   = bkok->i_dual.view_device();
907   m    = A->rmap->n; /* M and nnz of C */
908   nnz  = aseq->nz + bseq->nz;
909   if (reuse == MAT_INITIAL_MATRIX) {
910     aj           = akok->j_dual.view_device();
911     bj           = bkok->j_dual.view_device();
912     auto ca_dual = MatScalarKokkosDualView("a", nnz);
913     auto ci_dual = MatRowMapKokkosDualView("i", m + 1);
914     auto cj_dual = MatColIdxKokkosDualView("j", nnz);
915     ca           = ca_dual.view_device();
916     ci           = ci_dual.view_device();
917     cj           = cj_dual.view_device();
918 
919     // For each row of B, find number of nonzeros on the left of the diagonal block (i.e., A).
920     // The result is stored in mm->B_NzDiagLeft for reuse in the numeric phase
921     MatColIdxKokkosViewHost NzLeft("NzLeft", m);
922     const MatRowMapType    *rowptr = bkok->i_host_data();
923     const MatColIdxType    *colidx = bkok->j_host_data();
924     MatColIdxType          *nzleft = NzLeft.data();
925     const MatColIdxType     cstart = mat->cmap->rstart; // start of global column indices of A; used to split B
926     for (PetscInt i = 0; i < m; i++) {
927       const MatColIdxType *first, *last, *it;
928       PetscInt             count, step;
929 
930       // Basically, std::lower_bound(first,last,cstart), but need to map columns from local to global with garray[]
931       first = colidx + rowptr[i];
932       last  = colidx + rowptr[i + 1];
933       count = last - first;
934       while (count > 0) {
935         it   = first;
936         step = count / 2;
937         it += step;
938         if (garray[*it] < cstart) {
939           first = ++it;
940           count -= step + 1;
941         } else count = step;
942       }
943       nzleft[i] = first - (colidx + rowptr[i]);
944     }
945     auto B_NzDiagLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), NzLeft); // copy to device
946 
947     auto tmp = MatColIdxKokkosViewHost(const_cast<MatColIdxType *>(garray), B->cmap->n);
948     auto l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); // copy garray to device
949 
950     // Shuffle A and B in parallel using Kokkos hierarchical parallelism
951     Kokkos::parallel_for(
952       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
953         PetscInt i    = t.league_rank(); /* row i */
954         PetscInt disp = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
955         PetscInt nzleft = B_NzDiagLeft(i);
956 
957         Kokkos::single(Kokkos::PerTeam(t), [=]() {
958           ci(i) = disp;
959           if (i == m - 1) ci(m) = ai(m) + bi(m);
960         });
961 
962         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
963           if (k < nzleft) { // portion of B that is on left of A
964             ca(disp + k) = ba(bi(i) + k);
965             cj(disp + k) = l2g(bj(bi(i) + k));
966           } else if (k < nzleft + alen) { // diag A
967             ca(disp + k) = aa(ai(i) + k - nzleft);
968             cj(disp + k) = aj(ai(i) + k - nzleft) + cstart; // add the shift to convert local to global.
969           } else {                                          // portion of B that is on right of A
970             ca(disp + k) = ba(bi(i) + k - alen);
971             cj(disp + k) = l2g(bj(bi(i) + k - alen));
972           }
973         });
974       });
975     ca_dual.modify_device();
976     ci_dual.modify_device();
977     cj_dual.modify_device();
978     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, mat->cmap->N, nnz, ci_dual, cj_dual, ca_dual));
979     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
980     mm->B_NzDiagLeft = B_NzDiagLeft;
981   } else if (reuse == MAT_REUSE_MATRIX) {
982     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
983     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
984     ckok               = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
985     ca                 = ckok->a_dual.view_device();
986     auto &B_NzDiagLeft = mm->B_NzDiagLeft;
987 
988     Kokkos::parallel_for(
989       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
990         PetscInt i    = t.league_rank(); // row i
991         PetscInt disp = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
992         PetscInt nzleft = B_NzDiagLeft(i);
993 
994         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
995           if (k < nzleft) { // portion of B that is on left of A
996             ca(disp + k) = ba(bi(i) + k);
997           } else if (k < nzleft + alen) { // diag A
998             ca(disp + k) = aa(ai(i) + k - nzleft);
999           } else { // portion of B that is on right of A
1000             ca(disp + k) = ba(bi(i) + k - alen);
1001           }
1002         });
1003       });
1004 
1005     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
1006   }
1007   PetscFunctionReturn(PETSC_SUCCESS);
1008 }
1009 #endif
1010 
1011 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1012 
1013   Input Parameters:
1014 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1015 .  A        - an MPIAIJKOKKOS matrix
1016 .  B        - an MPIAIJKOKKOS matrix
1017 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1018 
1019   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
1020 */
1021 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1022 {
1023   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
1024   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
1025   IS                       glob = NULL;
1026   const PetscInt          *garray;
1027   PetscInt                 N = B->cmap->N, sz;
1028   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
1029   MatColIdxKokkosView      l2g2;
1030   Mat                      C1, C2; /* intermediate matrices */
1031 
1032   PetscFunctionBegin;
1033 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1034   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
1035   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
1036 #else
1037   PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_INITIAL_MATRIX, mm, &mm->B_local));
1038   PetscCall(ISCreateStride(MPI_COMM_SELF, N, 0, 1, &glob));
1039 #endif
1040 
1041   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
1042   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
1043   PetscCall(MatProductSetFill(C1, product->fill));
1044   C1->product->api_user = product->api_user;
1045   PetscCall(MatProductSetFromOptions(C1));
1046   PetscUseTypeMethod(C1, productsymbolic);
1047 
1048   PetscCall(ISGetIndices(glob, &garray));
1049   PetscCall(ISGetSize(glob, &sz));
1050   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
1051   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
1052   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
1053 
1054   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
1055   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
1056 
1057   /* Compact B_other to use local ids as we guess KK spgemm is more memory scalable with that; We could skip the compaction to simplify code */
1058   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
1059   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
1060   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
1061   PetscCall(MatProductSetFill(C2, product->fill));
1062   C2->product->api_user = product->api_user;
1063   PetscCall(MatProductSetFromOptions(C2));
1064   PetscUseTypeMethod(C2, productsymbolic);
1065   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
1066 
1067   /* C = C1 + C2.  We actually use their global col ids versions in adding */
1068   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
1069   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
1070   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
1071   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
1072 
1073   mm->C1 = C1;
1074   mm->C2 = C2;
1075   PetscCall(ISRestoreIndices(glob, &garray));
1076   PetscCall(ISDestroy(&glob));
1077   PetscFunctionReturn(PETSC_SUCCESS);
1078 }
1079 
1080 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
1081 
1082   Input Parameters:
1083 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1084 .  A        - an MPIAIJKOKKOS matrix
1085 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
1086 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
1087 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
1088 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
1089 -  mm       - a struct used to stash intermediate data in AtB
1090 
1091   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
1092 */
1093 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
1094 {
1095   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
1096   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
1097   Mat         C1, C2;               /* intermediate matrices */
1098 
1099   PetscFunctionBegin;
1100   /* C1 = Ad^t * B */
1101   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
1102   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
1103   PetscCall(MatProductSetFill(C1, product->fill));
1104   C1->product->api_user = product->api_user;
1105   PetscCall(MatProductSetFromOptions(C1));
1106   PetscUseTypeMethod(C1, productsymbolic);
1107 
1108   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
1109   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
1110 
1111   /* C2 = Ao^t * B */
1112   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
1113   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
1114   PetscCall(MatProductSetFill(C2, product->fill));
1115   C2->product->api_user = product->api_user;
1116   PetscCall(MatProductSetFromOptions(C2));
1117   PetscUseTypeMethod(C2, productsymbolic);
1118 
1119   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
1120 
1121   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
1122   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
1123   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
1124   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
1125   mm->C1 = C1;
1126   mm->C2 = C2;
1127   PetscFunctionReturn(PETSC_SUCCESS);
1128 }
1129 
1130 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1131 {
1132   Mat_Product                 *product = C->product;
1133   MatProductType               ptype;
1134   MatProductData_MPIAIJKokkos *mmdata;
1135   MatMatStruct                *mm = NULL;
1136   MatMatStruct_AB             *ab;
1137   MatMatStruct_AtB            *atb;
1138   Mat                          A, B, Ad, Ao, Bd, Bo;
1139   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1140 
1141   PetscFunctionBegin;
1142   MatCheckProduct(C, 1);
1143   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1144   ptype  = product->type;
1145   A      = product->A;
1146   B      = product->B;
1147   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1148   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1149   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1150   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1151 
1152   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1153     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1154     ab               = mmdata->mmAB;
1155     atb              = mmdata->mmAtB;
1156     if (ab) {
1157       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1158       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1159     }
1160     if (atb) {
1161       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1162       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1163     }
1164     PetscFunctionReturn(PETSC_SUCCESS);
1165   }
1166 
1167   if (ptype == MATPRODUCT_AB) {
1168     ab = mmdata->mmAB;
1169     /* C1 = Ad * B_local */
1170     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1171 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1172     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1173 #else
1174     PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_REUSE_MATRIX, ab, &ab->B_local));
1175 #endif
1176 
1177     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1178     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1179     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1180     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1181     /* C2 = Ao * B_other */
1182     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1183     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1184     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1185     /* C = C1_global + C2_global */
1186     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1187     mm = static_cast<MatMatStruct *>(ab);
1188   } else if (ptype == MATPRODUCT_AtB) {
1189     atb = mmdata->mmAtB;
1190     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1191     /* C1 = Ad^t * B_local */
1192     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1193     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1194     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1195     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1196 
1197     /* C2 = Ao^t * B_local */
1198     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1199     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1200     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1201     /* Form C2_global */
1202     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1203     /* C = C1_global + C2_global */
1204     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1205     mm = static_cast<MatMatStruct *>(atb);
1206   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1207     ab = mmdata->mmAB;
1208 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1209     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1210 #else
1211     PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_REUSE_MATRIX, ab, &ab->B_local));
1212 #endif
1213     /* ab->C1 = Ad * B_local */
1214     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1215     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1216     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1217     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1218     /* ab->C2 = Ao * B_other */
1219     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1220     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1221     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1222 
1223     /* atb->C1 = Bd^t * ab->C_petsc */
1224     atb = mmdata->mmAtB;
1225     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1226     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1227     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1228     /* atb->C2 = Bo^t * ab->C_petsc */
1229     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1230     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1231     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1232     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1233     mm = static_cast<MatMatStruct *>(atb);
1234   }
1235   /* Split C_global to form C */
1236   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1237   PetscFunctionReturn(PETSC_SUCCESS);
1238 }
1239 
1240 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1241 {
1242   Mat                          A, B;
1243   Mat_Product                 *product = C->product;
1244   MatProductType               ptype;
1245   MatProductData_MPIAIJKokkos *mmdata;
1246   MatMatStruct                *mm   = NULL;
1247   IS                           glob = NULL;
1248   const PetscInt              *garray;
1249   PetscInt                     m, n, M, N, sz;
1250   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1251 
1252   PetscFunctionBegin;
1253   MatCheckProduct(C, 1);
1254   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1255   ptype = product->type;
1256   A     = product->A;
1257   B     = product->B;
1258 
1259   switch (ptype) {
1260   case MATPRODUCT_AB:
1261     m = A->rmap->n;
1262     n = B->cmap->n;
1263     M = A->rmap->N;
1264     N = B->cmap->N;
1265     break;
1266   case MATPRODUCT_AtB:
1267     m = A->cmap->n;
1268     n = B->cmap->n;
1269     M = A->cmap->N;
1270     N = B->cmap->N;
1271     break;
1272   case MATPRODUCT_PtAP:
1273     m = B->cmap->n;
1274     n = B->cmap->n;
1275     M = B->cmap->N;
1276     N = B->cmap->N;
1277     break; /* BtAB */
1278   default:
1279     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1280   }
1281 
1282   PetscCall(MatSetSizes(C, m, n, M, N));
1283   PetscCall(PetscLayoutSetUp(C->rmap));
1284   PetscCall(PetscLayoutSetUp(C->cmap));
1285   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1286 
1287   mmdata           = new MatProductData_MPIAIJKokkos();
1288   mmdata->reusesym = product->api_user;
1289 
1290   if (ptype == MATPRODUCT_AB) {
1291     mmdata->mmAB = new MatMatStruct_AB();
1292     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1293     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1294   } else if (ptype == MATPRODUCT_AtB) {
1295     mmdata->mmAtB = new MatMatStruct_AtB();
1296     auto atb      = mmdata->mmAtB;
1297     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1298     PetscCall(ISGetIndices(glob, &garray));
1299     PetscCall(ISGetSize(glob, &sz));
1300     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
1301     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1302     PetscCall(ISRestoreIndices(glob, &garray));
1303     PetscCall(ISDestroy(&glob));
1304     mm = static_cast<MatMatStruct *>(atb);
1305   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1306     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1307     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1308     auto ab       = mmdata->mmAB;
1309     auto atb      = mmdata->mmAtB;
1310     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1311     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1312     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1313     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1314     mm = static_cast<MatMatStruct *>(atb);
1315   }
1316   /* Split the C_global into petsc A, B format */
1317   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1318   C->product->data       = mmdata;
1319   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1320   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1321   PetscFunctionReturn(PETSC_SUCCESS);
1322 }
1323 
1324 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1325 {
1326   Mat_Product *product = mat->product;
1327   PetscBool    match   = PETSC_FALSE;
1328   PetscBool    usecpu  = PETSC_FALSE;
1329 
1330   PetscFunctionBegin;
1331   MatCheckProduct(mat, 1);
1332   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1333   if (match) { /* we can always fallback to the CPU if requested */
1334     switch (product->type) {
1335     case MATPRODUCT_AB:
1336       if (product->api_user) {
1337         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1338         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1339         PetscOptionsEnd();
1340       } else {
1341         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1342         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1343         PetscOptionsEnd();
1344       }
1345       break;
1346     case MATPRODUCT_AtB:
1347       if (product->api_user) {
1348         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1349         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1350         PetscOptionsEnd();
1351       } else {
1352         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1353         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1354         PetscOptionsEnd();
1355       }
1356       break;
1357     case MATPRODUCT_PtAP:
1358       if (product->api_user) {
1359         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1360         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1361         PetscOptionsEnd();
1362       } else {
1363         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1364         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1365         PetscOptionsEnd();
1366       }
1367       break;
1368     default:
1369       break;
1370     }
1371     match = (PetscBool)!usecpu;
1372   }
1373   if (match) {
1374     switch (product->type) {
1375     case MATPRODUCT_AB:
1376     case MATPRODUCT_AtB:
1377     case MATPRODUCT_PtAP:
1378       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1379       break;
1380     default:
1381       break;
1382     }
1383   }
1384   /* fallback to MPIAIJ ops */
1385   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1386   PetscFunctionReturn(PETSC_SUCCESS);
1387 }
1388 
1389 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1390 {
1391   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1392   Mat_MPIAIJKokkos *mpikok;
1393 
1394   PetscFunctionBegin;
1395   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1396   mat->preallocated = PETSC_TRUE;
1397   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1398   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1399   PetscCall(MatZeroEntries(mat));
1400   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1401   delete mpikok;
1402   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1403   PetscFunctionReturn(PETSC_SUCCESS);
1404 }
1405 
1406 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1407 {
1408   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1409   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1410   Mat                         A = mpiaij->A, B = mpiaij->B;
1411   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1412   MatScalarKokkosView         Aa, Ba;
1413   MatScalarKokkosView         v1;
1414   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1415   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1416   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1417   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1418   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1419   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1420   PetscMemType                memtype;
1421 
1422   PetscFunctionBegin;
1423   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1424   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1425     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1426   } else {
1427     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1428   }
1429 
1430   if (imode == INSERT_VALUES) {
1431     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1432     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1433   } else {
1434     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1435     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1436   }
1437 
1438   /* Pack entries to be sent to remote */
1439   Kokkos::parallel_for(
1440     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1441 
1442   /* Send remote entries to their owner and overlap the communication with local computation */
1443   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1444   /* Add local entries to A and B in one kernel */
1445   Kokkos::parallel_for(
1446     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1447       PetscScalar sum = 0.0;
1448       if (i < Annz) {
1449         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1450         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1451       } else {
1452         i -= Annz;
1453         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1454         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1455       }
1456     });
1457   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1458 
1459   /* Add received remote entries to A and B in one kernel */
1460   Kokkos::parallel_for(
1461     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1462       if (i < Annz2) {
1463         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1464       } else {
1465         i -= Annz2;
1466         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1467       }
1468     });
1469 
1470   if (imode == INSERT_VALUES) {
1471     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1472     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1473   } else {
1474     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1475     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1476   }
1477   PetscFunctionReturn(PETSC_SUCCESS);
1478 }
1479 
1480 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1481 {
1482   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1483 
1484   PetscFunctionBegin;
1485   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1486   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1487   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1488   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1489   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1490   PetscCall(MatDestroy_MPIAIJ(A));
1491   PetscFunctionReturn(PETSC_SUCCESS);
1492 }
1493 
1494 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1495 {
1496   Mat         B;
1497   Mat_MPIAIJ *a;
1498 
1499   PetscFunctionBegin;
1500   if (reuse == MAT_INITIAL_MATRIX) {
1501     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1502   } else if (reuse == MAT_REUSE_MATRIX) {
1503     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1504   }
1505   B = *newmat;
1506 
1507   B->boundtocpu = PETSC_FALSE;
1508   PetscCall(PetscFree(B->defaultvectype));
1509   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1510   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1511 
1512   a = static_cast<Mat_MPIAIJ *>(A->data);
1513   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1514   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1515   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1516 
1517   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1518   B->ops->mult                  = MatMult_MPIAIJKokkos;
1519   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1520   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1521   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1522   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1523 
1524   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1525   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1526   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1527   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1528   PetscFunctionReturn(PETSC_SUCCESS);
1529 }
1530 /*MC
1531    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1532 
   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1534 
1535    Options Database Key:
1536 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1537 
1538   Level: beginner
1539 
1540 .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1541 M*/
1542 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1543 {
1544   PetscFunctionBegin;
1545   PetscCall(PetscKokkosInitializeCheck());
1546   PetscCall(MatCreate_MPIAIJ(A));
1547   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1548   PetscFunctionReturn(PETSC_SUCCESS);
1549 }
1550 
1551 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameter `nz` (or the array `nnz`).
1557 
1558    Collective
1559 
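   Input Parameters:
+  comm - MPI communicator
.  m - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given)
.  n - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given)
.  M - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
.  N - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           of the local submatrix (possibly different for each row) or `NULL`
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows);
          not used when `comm` has a single rank
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           of the local submatrix (possibly different for each row) or `NULL`; not used when `comm` has a single rank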
1567 
1568    Output Parameter:
1569 .  A - the matrix
1570 
1571    Level: intermediate
1572 
1573    Notes:
1574    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1575    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1576    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1577 
1578    If `nnz` is given then `nz` is ignored
1579 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
   storage.  That is, the stored row and column indices can begin at
   either one (as in Fortran) or zero.
1583 
1584    Specify the preallocated storage with either nz or nnz (not both).
1585    Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
1586    allocation.  For large problems you MUST preallocate memory or you
1587    will get TERRIBLE performance, see the users' manual chapter on matrices.
1588 
1589    By default, this format uses inodes (identical nodes) when possible, to
1590    improve numerical efficiency of matrix-vector products and solves. We
1591    search for consecutive rows with the same nonzero structure, thereby
1592    reusing matrix information to achieve increased efficiency.
1593 
1594    Developer Note:
1595    This manual page is for the sequential constructor, not the parallel constructor
1596 
.seealso: [](chapter_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1599 @*/
1600 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1601 {
1602   PetscMPIInt size;
1603 
1604   PetscFunctionBegin;
1605   PetscCall(MatCreate(comm, A));
1606   PetscCall(MatSetSizes(*A, m, n, M, N));
1607   PetscCallMPI(MPI_Comm_size(comm, &size));
1608   if (size > 1) {
1609     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1610     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1611   } else {
1612     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1613     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1614   }
1615   PetscFunctionReturn(PETSC_SUCCESS);
1616 }
1617 
1618 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1619 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1620 {
1621   PetscMPIInt                size, rank;
1622   MPI_Comm                   comm;
1623   PetscSplitCSRDataStructure d_mat = NULL;
1624 
1625   PetscFunctionBegin;
1626   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1627   PetscCallMPI(MPI_Comm_size(comm, &size));
1628   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1629   if (size == 1) {
1630     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1631     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1632   } else {
1633     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1634     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1635     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1636     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1637     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1638   }
1639   // act like MatSetValues because not called on host
1640   if (A->assembled) {
1641     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1642     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1643   } else {
1644     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1645   }
1646   if (!d_mat) {
1647     struct _n_SplitCSRMat h_mat; /* host container */
1648     Mat_SeqAIJKokkos     *aijkokA;
1649     Mat_SeqAIJ           *jaca;
1650     PetscInt              n = A->rmap->n, nnz;
1651     Mat                   Amat;
1652     PetscInt             *colmap;
1653 
1654     /* create and copy h_mat */
1655     h_mat.M = A->cmap->N; // use for debug build
1656     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1657     if (size == 1) {
1658       Amat            = A;
1659       jaca            = (Mat_SeqAIJ *)A->data;
1660       h_mat.rstart    = 0;
1661       h_mat.rend      = A->rmap->n;
1662       h_mat.cstart    = 0;
1663       h_mat.cend      = A->cmap->n;
1664       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1665       h_mat.offdiag.a                   = NULL;
1666       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1667     } else {
1668       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1669       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1670       PetscInt          ii;
1671       Mat_SeqAIJKokkos *aijkokB;
1672 
1673       Amat    = aij->A;
1674       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1675       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1676       jaca    = (Mat_SeqAIJ *)aij->A->data;
1677       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1678       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1679       aij->donotstash          = PETSC_TRUE;
1680       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1681       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1682       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1683       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1684       // allocate B copy data
1685       h_mat.rstart = A->rmap->rstart;
1686       h_mat.rend   = A->rmap->rend;
1687       h_mat.cstart = A->cmap->rstart;
1688       h_mat.cend   = A->cmap->rend;
1689       nnz          = jacb->i[n];
1690       if (jacb->compressedrow.use) {
1691         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1692         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1693         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1694         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1695       } else {
1696         h_mat.offdiag.i = aijkokB->i_device_data();
1697       }
1698       h_mat.offdiag.j = aijkokB->j_device_data();
1699       h_mat.offdiag.a = aijkokB->a_device_data();
1700       {
1701         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1702         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1703         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1704         h_mat.colmap = aijkokB->colmap_d.data();
1705         PetscCall(PetscFree(colmap));
1706       }
1707       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1708       h_mat.offdiag.n                 = n;
1709     }
1710     // allocate A copy data
1711     nnz                          = jaca->i[n];
1712     h_mat.diag.n                 = n;
1713     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1714     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1715     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
1716     h_mat.diag.i = aijkokA->i_device_data();
1717     h_mat.diag.j = aijkokA->j_device_data();
1718     h_mat.diag.a = aijkokA->a_device_data();
1719     // copy pointers and metadata to device
1720     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1721     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1722     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1723   }
1724   *B           = d_mat;       // return it, set it in Mat, and set it up
1725   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1726   PetscFunctionReturn(PETSC_SUCCESS);
1727 }
1728 
1729 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1730 {
1731   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1732 
1733   PetscFunctionBegin;
1734   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1735   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1736   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1737   else *mask = "PETSC_OFFLOAD_BOTH";
1738   PetscFunctionReturn(PETSC_SUCCESS);
1739 }
1740 
1741 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1742 {
1743   PetscMPIInt size;
1744   Mat         Ad, Ao;
1745   const char *amask, *bmask;
1746 
1747   PetscFunctionBegin;
1748   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1749 
1750   if (size == 1) {
1751     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1752     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1753   } else {
1754     Ad = ((Mat_MPIAIJ *)A->data)->A;
1755     Ao = ((Mat_MPIAIJ *)A->data)->B;
1756     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1757     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1758     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1759   }
1760   PetscFunctionReturn(PETSC_SUCCESS);
1761 }
1762