xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision b6e6beb40ac0a73ccc19b70153df96d5058dce43)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
9 {
10   Mat_SeqAIJKokkos *aijkok;
11   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
12 
13   PetscFunctionBegin;
14   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
15   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
16      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
17    */
18   if (mode == MAT_FINAL_ASSEMBLY) {
19     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
20     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
21     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
22   }
23   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
24   if (aijkok && aijkok->device_mat_d.data()) {
25     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
26   }
27 
28   PetscFunctionReturn(PETSC_SUCCESS);
29 }
30 
31 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
32 {
33   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
34 
35   PetscFunctionBegin;
36   PetscCall(PetscLayoutSetUp(mat->rmap));
37   PetscCall(PetscLayoutSetUp(mat->cmap));
38 #if defined(PETSC_USE_DEBUG)
39   if (d_nnz) {
40     PetscInt i;
41     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
42   }
43   if (o_nnz) {
44     PetscInt i;
45     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
46   }
47 #endif
48 #if defined(PETSC_USE_CTABLE)
49   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
50 #else
51   PetscCall(PetscFree(mpiaij->colmap));
52 #endif
53   PetscCall(PetscFree(mpiaij->garray));
54   PetscCall(VecDestroy(&mpiaij->lvec));
55   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
56   /* Because the B will have been resized we simply destroy it and create a new one each time */
57   PetscCall(MatDestroy(&mpiaij->B));
58 
59   if (!mpiaij->A) {
60     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
61     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
62   }
63   if (!mpiaij->B) {
64     PetscMPIInt size;
65     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
66     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
67     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
68   }
69   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
70   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
71   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
72   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
73   mat->preallocated = PETSC_TRUE;
74   PetscFunctionReturn(PETSC_SUCCESS);
75 }
76 
77 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
78 {
79   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
80   PetscInt    nt;
81 
82   PetscFunctionBegin;
83   PetscCall(VecGetLocalSize(xx, &nt));
84   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
85   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
86   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
87   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
88   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
89   PetscFunctionReturn(PETSC_SUCCESS);
90 }
91 
92 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
93 {
94   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
95   PetscInt    nt;
96 
97   PetscFunctionBegin;
98   PetscCall(VecGetLocalSize(xx, &nt));
99   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
100   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
101   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
102   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
103   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
104   PetscFunctionReturn(PETSC_SUCCESS);
105 }
106 
107 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
108 {
109   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
110   PetscInt    nt;
111 
112   PetscFunctionBegin;
113   PetscCall(VecGetLocalSize(xx, &nt));
114   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
115   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
116   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
117   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
118   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119   PetscFunctionReturn(PETSC_SUCCESS);
120 }
121 
122 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
123    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
124    C still uses local column ids. Their corresponding global column ids are returned in glob.
125 */
126 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
127 {
128   Mat             Ad, Ao;
129   const PetscInt *cmap;
130 
131   PetscFunctionBegin;
132   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
133   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
134   if (glob) {
135     PetscInt cst, i, dn, on, *gidx;
136     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
137     PetscCall(MatGetLocalSize(Ao, NULL, &on));
138     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
139     PetscCall(PetscMalloc1(dn + on, &gidx));
140     for (i = 0; i < dn; i++) gidx[i] = cst + i;
141     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
142     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
143   }
144   PetscFunctionReturn(PETSC_SUCCESS);
145 }
146 
147 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
148 struct MatMatStruct {
149   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
150   PetscSF             sf;      /* SF to send/recv matrix entries */
151   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
152   Mat                 C1, C2, B_local;
153   KokkosCsrMatrix     C1_global, C2_global, C_global;
154   KernelHandle        kh;
155   MatMatStruct() noexcept : sf(nullptr), C1(nullptr), C2(nullptr), B_local(nullptr) { }
156 
157   ~MatMatStruct()
158   {
159     PetscFunctionBegin;
160     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C1));
161     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C2));
162     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_local));
163     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
164     kh.destroy_spadd_handle();
165     PetscFunctionReturnVoid();
166   }
167 };
168 
169 struct MatMatStruct_AB : public MatMatStruct {
170   MatColIdxKokkosView rows{};
171   MatRowMapKokkosView rowoffset{};
172   Mat                 B_other{}, C_petsc{}; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
173 
174   ~MatMatStruct_AB() noexcept
175   {
176     PetscFunctionBegin;
177     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_other));
178     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C_petsc));
179     PetscFunctionReturnVoid();
180   }
181 };
182 
183 struct MatMatStruct_AtB : public MatMatStruct {
184   MatRowMapKokkosView srcrowoffset, dstrowoffset;
185 };
186 
187 struct MatProductData_MPIAIJKokkos {
188   MatMatStruct_AB  *mmAB     = nullptr;
189   MatMatStruct_AtB *mmAtB    = nullptr;
190   PetscBool         reusesym = PETSC_FALSE;
191 
192   ~MatProductData_MPIAIJKokkos()
193   {
194     delete mmAB;
195     delete mmAtB;
196   }
197 };
198 
199 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
200 {
201   PetscFunctionBegin;
202   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
203   PetscFunctionReturn(PETSC_SUCCESS);
204 }
205 
206 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
207 
208    Input Parameters:
209 +  A       - the MATSEQAIJKOKKOS matrix
210 .  N       - new column size for the returned Kokkos matrix
211 -  l2g     - a map that maps old col ids to new col ids
212 
213    Output Parameters:
214 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
215  */
216 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
217 {
218   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
219   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
220 
221   PetscFunctionBegin;
222   PetscCallCXX(Kokkos::parallel_for(
223     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
224   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
225   PetscFunctionReturn(PETSC_SUCCESS);
226 }
227 
228 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
229    It is similar to MatCreateMPIAIJWithSplitArrays.
230 
231   Input Parameters:
232 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
233 .  A     - the diag matrix using local col ids
234 -  B     - the offdiag matrix using global col ids
235 
236   Output Parameters:
237 .  mat   - the updated MATMPIAIJKOKKOS matrix
238 */
239 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
240 {
241   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
242   PetscInt          m, n, M, N, Am, An, Bm, Bn;
243   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
244 
245   PetscFunctionBegin;
246   PetscCall(MatGetSize(mat, &M, &N));
247   PetscCall(MatGetLocalSize(mat, &m, &n));
248   PetscCall(MatGetLocalSize(A, &Am, &An));
249   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
250 
251   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
252   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
253   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
254   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
255   mpiaij->A = A;
256   mpiaij->B = B;
257 
258   mat->preallocated     = PETSC_TRUE;
259   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
260 
261   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
262   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
263   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
264     also gets mpiaij->B compacted, with its col ids and size reduced
265   */
266   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
267   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
268   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
269 
270   /* Update bkok with new local col ids (stored on host) and size */
271   bkok->j_dual.modify_host();
272   bkok->j_dual.sync_device();
273   bkok->SetColSize(mpiaij->B->cmap->n);
274   PetscFunctionReturn(PETSC_SUCCESS);
275 }
276 
277 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
278 
279    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
280    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
281    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
282    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
283 
284    Collective on comm of ownerSF
285 
286    Input Parameters:
287 +   B       - the SEQAIJKOKKOS matrix, using local col ids
288 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
289 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
290 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
291 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
292 
293    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
294 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
295 .   abuf      - buffer for sending matrix values
296 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
297                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
298 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
299 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
300 */
301 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
302 {
303   Mat_SeqAIJKokkos *bkok, *ckok;
304 
305   PetscFunctionBegin;
306   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
307   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
308 
309   if (reuse == MAT_REUSE_MATRIX) {
310     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
311 
312     const auto &Ba = bkok->a_dual.view_device();
313     const auto &Bi = bkok->i_dual.view_device();
314     const auto &Ca = ckok->a_dual.view_device();
315 
316     /* Copy Ba to abuf */
317     Kokkos::parallel_for(
318       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
319         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
320         PetscInt r    = rows(i);
321         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
322         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
323       });
324 
325     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
326     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
327     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
328     ckok->a_dual.modify_device();
329   } else if (reuse == MAT_INITIAL_MATRIX) {
330     MPI_Comm    comm;
331     PetscMPIInt tag;
332     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
333 
334     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
335     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
336     Cm = nleaves; /* row size of C */
337     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
338 
339     /* Get row lens (nz) of B's rows for later fast query */
340     PetscInt       *Browlens;
341     const PetscInt *tmp = bkok->i_host_data();
342     PetscCall(PetscMalloc1(nroots, &Browlens));
343     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
344 
345     /* By ownerSF, each proc gets lens of rows of C */
346     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
347     Ci_h    = Ci.view_host().data();
348     Ci_h[0] = 0;
349     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
350     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
351     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
352     Cnnz = Ci_h[Cm];
353     Ci.modify_host();
354     Ci.sync_device();
355 
356     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
357     MatColIdxKokkosDualView Cj("j", Cnnz);
358     MatScalarKokkosDualView Ca("a", Cnnz);
359 
360     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
361     const PetscMPIInt *iranks, *ranks;
362     const PetscInt    *ioffset, *irootloc, *roffset;
363     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
364     MPI_Request       *reqs;
365 
366     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
367     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
368 
369     /* figure out offsets at the send buffer, to build the SF
370       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
371       rowptr[] - stores offsets for data of each row in abuf
372 
373       rdisp[]  - to receive sdisp[]
374     */
375     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
376     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
377     rowptr = rowptr_h.data();
378 
379     sdisp[0]  = 0;
380     rowptr[0] = 0;
381     for (i = 0; i < niranks; i++) { /* for each receiver */
382       PetscInt len, nz = 0;
383       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
384         len           = Browlens[irootloc[j]];
385         rowptr[j + 1] = rowptr[j] + len;
386         nz += len;
387       }
388       sdisp[i + 1] = sdisp[i] + nz;
389     }
390     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
391     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
392     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
393     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
394 
395     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
396     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
397     PetscSFNode *iremote;
398     PetscCall(PetscMalloc1(nleaves2, &iremote));
399     for (i = 0; i < nranks; i++) { /* for each sender */
400       k = 0;
401       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
402         iremote[j].rank  = ranks[i];
403         iremote[j].index = rdisp[i] + k;
404         k++;
405       }
406     }
407     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
408     PetscCall(PetscSFCreate(comm, &bcastSF));
409     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
410 
411     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
412       from local to global. Then use bcastSF to fill Ca, Cj.
413     */
414     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
415     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
416     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
417 
418     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
419 
420     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
421     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
422 
423     const auto &Ba = bkok->a_dual.view_device();
424     const auto &Bi = bkok->i_dual.view_device();
425     const auto &Bj = bkok->j_dual.view_device();
426 
427     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
428     Kokkos::parallel_for(
429       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
430         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
431         PetscInt r    = rows(i);
432         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
433         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
434           abuf(base + k) = Ba(Bi(r) + k);
435           jbuf(base + k) = l2g(Bj(Bi(r) + k));
436         });
437       });
438 
439     /* Send abuf & jbuf to fill Ca, Cj */
440     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
441     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
442     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
443     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
444     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
445     Cj.sync_host();
446     Ca.modify_device();
447 
448     /* Construct C with Ca, Ci, Cj */
449     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
450     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
451     PetscCall(PetscFree3(sdisp, rdisp, reqs));
452     PetscCall(PetscFree(Browlens));
453   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
454   PetscFunctionReturn(PETSC_SUCCESS);
455 }
456 
457 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
458 
459   It is the reverse of MatSeqAIJKokkosBcast in some sense.
460 
461   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
462   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
463   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
464 
465   Input Parameters:
466 +  A        - the SEQAIJKOKKOS matrix to be reduced
467 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
468 .  local    - true if A uses local col ids; false if A is already in global col ids.
469 .  N        - if local, N is A's global col size
470 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
471 -  ownerSF  - the SF specifies ownership (root) of rows in A
472 
473   Output Parameters:
474 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
475 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
476 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
477 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
478                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
479 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
480 
481    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
482  */
483 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
484 {
485   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
486   const PetscInt   *Ai;
487   Mat_SeqAIJKokkos *akok;
488 
489   PetscFunctionBegin;
490   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
491   PetscCall(MatGetSize(A, &Am, &An));
492   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
493   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
494   Annz = Ai[Am];
495 
496   if (reuse == MAT_REUSE_MATRIX) {
497     /* Send Aa to abuf */
498     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
499     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
500 
501     /* Copy abuf to Ca */
502     const MatScalarKokkosView &Ca = C.values;
503     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
504     Kokkos::parallel_for(
505       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
506         PetscInt i   = t.league_rank();
507         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
508         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
509         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
510       });
511   } else if (reuse == MAT_INITIAL_MATRIX) {
512     MPI_Comm     comm;
513     MPI_Request *reqs;
514     PetscMPIInt  tag;
515     PetscInt     Cm;
516 
517     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
518     PetscCall(PetscCommGetNewTag(comm, &tag));
519 
520     PetscInt           niranks, nranks, nroots, nleaves;
521     const PetscMPIInt *iranks, *ranks;
522     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
523     PetscCall(PetscSFSetUp(ownerSF));
524     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
525     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
526     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
527     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
528     Cm    = nroots;
529     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
530 
531     /* Tell owners how long each row I will send */
532     PetscInt               *srowlens;                              /* send buf of row lens */
533     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
534     PetscInt               *rrowlens = rrowlens_h.data();
535 
536     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
537     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
538     rrowlens[0] = 0;
539     rrowlens++; /* shift the pointer to make the following expression more readable */
540     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
541     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
542     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
543 
544     /* Owner builds Ci on host by histogramming rrowlens[] */
545     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
546     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
547     MatRowMapType *Ci_ptr = Ci_h.data();
548 
549     for (i = 0; i < nrows; i++) {
550       r = rows[i]; /* local row id of i-th received row */
551 #if defined(PETSC_USE_DEBUG)
552       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
553 #endif
554       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
555     }
556     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
557     Cnnz = Ci_ptr[Cm];
558 
559     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
560     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
561     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
562     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
563 
564     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
565     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
566       r                    = rows[i];                   /* row id in C */
567       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
568       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
569     }
570     PetscCall(PetscFree(currowlens));
571 
572     rrowlens--;
573     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
574     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
575     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
576 
577     /* Build the reduceSF, which performs buffer to buffer send/recv */
578     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
579     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
580     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
581     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
582     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
583     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
584 
585     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
586     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
587     PetscSFNode *iremote;
588     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
589     for (i = 0; i < nranks; i++) {
590       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
591       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
592       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
593       for (PetscInt k = 0; k < nz; k++) {
594         iremote[leafbase + k].rank  = ranks[i];
595         iremote[leafbase + k].index = rootbase + k;
596       }
597     }
598     PetscCall(PetscSFCreate(comm, &reduceSF));
599     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
600     PetscCall(PetscFree2(sdisp, rdisp));
601 
602     /* Reduce Aa, Ajg to abuf and jbuf */
603 
604     /* If A uses local col ids, convert them to global ones before sending */
605     MatColIdxKokkosView Ajg;
606     if (local) {
607       Ajg                           = MatColIdxKokkosView("j", Annz);
608       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
609       Kokkos::parallel_for(
610         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
611     } else {
612       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
613     }
614 
615     MatColIdxKokkosView jbuf("jbuf", Cnnz);
616     abuf = MatScalarKokkosView("abuf", Cnnz);
617     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
618     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
619     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
620     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
621 
622     /* Copy data from abuf, jbuf to Ca, Cj */
623     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
624     MatColIdxKokkosView Cj("j", Cnnz);
625     MatScalarKokkosView Ca("a", Cnnz);
626 
627     Kokkos::parallel_for(
628       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
629         PetscInt i   = t.league_rank();
630         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
631         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
632         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
633           Ca(dst + k) = abuf(src + k);
634           Cj(dst + k) = jbuf(src + k);
635         });
636       });
637 
638     /* Build C with Ca, Ci, Cj */
639     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
640     PetscCall(PetscFree2(srowlens, reqs));
641   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
642   PetscFunctionReturn(PETSC_SUCCESS);
643 }
644 
645 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
646 
647   Input Parameters:
648 +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
649 .  reuse    - indicate whether the matrix has called this function before
650 .  csrmat   - the KokkosCsrMatrix, of size m,N
651 -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
652               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
653               entry is 5, then Cdstart[i] = 3.
654 
655   Output Parameters:
656 +  C        - the updated `MATMPIAIJKOKKOS` matrix
657 -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
658 
659   Note:
660    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
661 
662 .seealso: `MATMPIAIJKOKKOS`
663  */
664 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
665 {
666   const MatScalarKokkosView      &Ca = csrmat.values;
667   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
668   PetscInt                        m, n, N;
669 
670   PetscFunctionBegin;
671   PetscCall(MatGetLocalSize(C, &m, &n));
672   PetscCall(MatGetSize(C, NULL, &N));
673 
674   if (reuse == MAT_REUSE_MATRIX) {
675     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
676     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
677     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
678     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
679     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
680 
681     /* Fill 'a' of Cd and Co on device */
682     Kokkos::parallel_for(
683       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
684         PetscInt i       = t.league_rank();     /* row i */
685         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
686         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
687         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
688         PetscInt cdend   = cdstart + cdlen;
689         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
690         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
691           if (k < cdstart) { /* k in [0, cdstart) */
692             Coa(Coi(i) + k) = Ca(Ci(i) + k);
693           } else if (k < cdend) { /* k in [cdstart, cdend) */
694             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
695           } else { /* k in [cdend, clen) */
696             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
697           }
698         });
699       });
700 
701     akok->a_dual.modify_device();
702     bkok->a_dual.modify_device();
703   } else if (reuse == MAT_INITIAL_MATRIX) {
704     Mat                        Cd, Co;
705     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
706     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
707     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
708     PetscInt                   cstart, cend;
709 
710     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
711        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
712        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
713        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
714        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
715      */
716     Cdstart = MatRowMapKokkosView("Cdstart", m);
717     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
718 
719     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
720       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
721      */
722     Kokkos::parallel_for(
723       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
724         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
725                                                    PetscInt i = t.league_rank(); /* row i */
726                                                    PetscInt j, first, count, step;
727 
728                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
729                                                      Cdi(0) = 0;
730                                                      Coi(0) = 0;
731                                                    }
732 
733                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
734           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
735         */
736                                                    count = Ci(i + 1) - Ci(i);
737                                                    first = Ci(i);
738                                                    while (count > 0) {
739                                                      j    = first;
740                                                      step = count / 2;
741                                                      j += step;
742                                                      if (Cj(j) < cstart) {
743                                                        first = ++j;
744                                                        count -= step + 1;
745                                                      } else count = step;
746                                                    }
747                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
748 
749                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
750                                                    count = Ci(i + 1) - first;
751                                                    while (count > 0) {
752                                                      j    = first;
753                                                      step = count / 2;
754                                                      j += step;
755                                                      if (Cj(j) < cend) {
756                                                        first = ++j;
757                                                        count -= step + 1;
758                                                      } else count = step;
759                                                    }
760                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
761                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
762         });
763       });
764 
765     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
766     Kokkos::parallel_scan(
767       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
768         update += Cdi(i);
769         if (final) Cdi(i) = update;
770       });
771     Kokkos::parallel_scan(
772       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
773         update += Coi(i);
774         if (final) Coi(i) = update;
775       });
776 
777     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
778        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
779     */
780     Cdi_dual.modify_device();
781     Coi_dual.modify_device();
782     Cdi_dual.sync_host();
783     Coi_dual.sync_host();
784     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
785     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
786 
787     /* With nnz, allocate a, j for Cd and Co */
788     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
789     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
790 
791     /* Fill a, j of Cd and Co on device */
792     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
793     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
794 
795     Kokkos::parallel_for(
796       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
797         PetscInt i       = t.league_rank();     /* row i */
798         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
799         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
800         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
801         PetscInt cdend   = cdstart + cdlen;
802         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
803         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
804           if (k < cdstart) { /* k in [0, cdstart) */
805             Coa(Coi(i) + k) = Ca(Ci(i) + k);
806             Coj(Coi(i) + k) = Cj(Ci(i) + k);
807           } else if (k < cdend) { /* k in [cdstart, cdend) */
808             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
809             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
810           } else {                                                /* k in [cdend, clen) */
811             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
812             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
813           }
814         });
815       });
816 
817     Cdj_dual.modify_device();
818     Cda_dual.modify_device();
819     Coj_dual.modify_device();
820     Coa_dual.modify_device();
821     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
822     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
823     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
824     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
825     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
826     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
827   }
828   PetscFunctionReturn(PETSC_SUCCESS);
829 }
830 
831 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
832 
833   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
834 
835   Input Parameters:
836 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
837 .  reuse    - indicate whether the matrix has called this function before
838 .  csrmat   - the KokkosCsrMatrix, of size m,N
839 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
840               entry of the diag block of C in csrmat's j array.
841 
842   Output Parameters:
843 +  C        - the updated MATMPIAIJKOKKOS matrix
844 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
845 
846   Note:
847   the input matrix's col ids and col size will be changed.
848 */
849 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
850 {
851   Mat_SeqAIJKokkos      *ckok;
852   ISLocalToGlobalMapping l2gmap;
853   const PetscInt        *garray;
854   PetscInt               sz;
855 
856   PetscFunctionBegin;
857   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
858   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
859   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
860   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
861   ckok->j_dual.sync_device();
862   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
863 
864   /* Build l2g -- the local to global mapping of C's cols */
865   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
866   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
867   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
868 
869   ConstMatColIdxKokkosViewHost tmp(garray, sz);
870   l2g = MatColIdxKokkosView("l2g", sz);
871   Kokkos::deep_copy(l2g, tmp);
872 
873   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
874   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
875   PetscFunctionReturn(PETSC_SUCCESS);
876 }
877 
878 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
879 
880   Input Parameters:
881 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
882 .  A        - an MPIAIJKOKKOS matrix
883 .  B        - an MPIAIJKOKKOS matrix
884 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
885 
886   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
887 */
888 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
889 {
890   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
891   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
892   IS                       glob = NULL;
893   const PetscInt          *garray;
894   PetscInt                 N = B->cmap->N, sz;
895   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
896   MatColIdxKokkosView      l2g2;
897   Mat                      C1, C2; /* intermediate matrices */
898 
899   PetscFunctionBegin;
900   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
901   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
902   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
903   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
904   PetscCall(MatProductSetFill(C1, product->fill));
905   C1->product->api_user = product->api_user;
906   PetscCall(MatProductSetFromOptions(C1));
907   PetscUseTypeMethod(C1, productsymbolic);
908 
909   PetscCall(ISGetIndices(glob, &garray));
910   PetscCall(ISGetSize(glob, &sz));
911   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
912   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
913   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
914 
915   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
916   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
917 
918   /* Compact B_other to use local ids as we guess KK spgemm is more memory scalable with that; We could skip the compaction to simplify code */
919   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
920   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
921   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
922   PetscCall(MatProductSetFill(C2, product->fill));
923   C2->product->api_user = product->api_user;
924   PetscCall(MatProductSetFromOptions(C2));
925   PetscUseTypeMethod(C2, productsymbolic);
926   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
927 
928   /* C = C1 + C2.  We actually use their global col ids versions in adding */
929   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
930   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
931   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
932   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
933 
934   mm->C1 = C1;
935   mm->C2 = C2;
936   PetscCall(ISRestoreIndices(glob, &garray));
937   PetscCall(ISDestroy(&glob));
938   PetscFunctionReturn(PETSC_SUCCESS);
939 }
940 
941 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
942 
943   Input Parameters:
944 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
945 .  A        - an MPIAIJKOKKOS matrix
946 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
947 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
948 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
949 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
950 -  mm       - a struct used to stash intermediate data in AtB
951 
952   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
953 */
954 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
955 {
956   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
957   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
958   Mat         C1, C2;               /* intermediate matrices */
959 
960   PetscFunctionBegin;
961   /* C1 = Ad^t * B */
962   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
963   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
964   PetscCall(MatProductSetFill(C1, product->fill));
965   C1->product->api_user = product->api_user;
966   PetscCall(MatProductSetFromOptions(C1));
967   PetscUseTypeMethod(C1, productsymbolic);
968 
969   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
970   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
971 
972   /* C2 = Ao^t * B */
973   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
974   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
975   PetscCall(MatProductSetFill(C2, product->fill));
976   C2->product->api_user = product->api_user;
977   PetscCall(MatProductSetFromOptions(C2));
978   PetscUseTypeMethod(C2, productsymbolic);
979 
980   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
981 
982   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
983   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
984   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
985   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
986   mm->C1 = C1;
987   mm->C2 = C2;
988   PetscFunctionReturn(PETSC_SUCCESS);
989 }
990 
991 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
992 {
993   Mat_Product                 *product = C->product;
994   MatProductType               ptype;
995   MatProductData_MPIAIJKokkos *mmdata;
996   MatMatStruct                *mm = NULL;
997   MatMatStruct_AB             *ab;
998   MatMatStruct_AtB            *atb;
999   Mat                          A, B, Ad, Ao, Bd, Bo;
1000   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1001 
1002   PetscFunctionBegin;
1003   MatCheckProduct(C, 1);
1004   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1005   ptype  = product->type;
1006   A      = product->A;
1007   B      = product->B;
1008   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1009   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1010   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1011   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1012 
1013   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1014     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1015     ab               = mmdata->mmAB;
1016     atb              = mmdata->mmAtB;
1017     if (ab) {
1018       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1019       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1020     }
1021     if (atb) {
1022       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1023       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1024     }
1025     PetscFunctionReturn(PETSC_SUCCESS);
1026   }
1027 
1028   if (ptype == MATPRODUCT_AB) {
1029     ab = mmdata->mmAB;
1030     /* C1 = Ad * B_local */
1031     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1032     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1033     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1034     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1035     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1036     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1037     /* C2 = Ao * B_other */
1038     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1039     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1040     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1041     /* C = C1_global + C2_global */
1042     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1043     mm = static_cast<MatMatStruct *>(ab);
1044   } else if (ptype == MATPRODUCT_AtB) {
1045     atb = mmdata->mmAtB;
1046     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1047     /* C1 = Ad^t * B_local */
1048     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1049     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1050     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1051     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1052 
1053     /* C2 = Ao^t * B_local */
1054     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1055     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1056     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1057     /* Form C2_global */
1058     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1059     /* C = C1_global + C2_global */
1060     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1061     mm = static_cast<MatMatStruct *>(atb);
1062   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1063     ab = mmdata->mmAB;
1064     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1065 
1066     /* ab->C1 = Ad * B_local */
1067     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1068     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1069     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1070     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1071     /* ab->C2 = Ao * B_other */
1072     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1073     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1074     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1075 
1076     /* atb->C1 = Bd^t * ab->C_petsc */
1077     atb = mmdata->mmAtB;
1078     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1079     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1080     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1081     /* atb->C2 = Bo^t * ab->C_petsc */
1082     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1083     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1084     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1085     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1086     mm = static_cast<MatMatStruct *>(atb);
1087   }
1088   /* Split C_global to form C */
1089   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1090   PetscFunctionReturn(PETSC_SUCCESS);
1091 }
1092 
/* Symbolic phase of the MATMPIAIJKOKKOS products
     C = A*B     (MATPRODUCT_AB),
     C = A^T*B   (MATPRODUCT_AtB),
     C = B^T*A*B (MATPRODUCT_PtAP, B playing the role of P).

   Sets C's local/global sizes and gives C the type of A. It then builds the product's global CSR matrix
   (mm->C_global) with the _AB/_AtB helpers and splits it into C's diagonal/off-diagonal blocks. Finally it attaches
   a MatProductData_MPIAIJKokkos to C->product, together with its destroy callback, and installs
   MatProductNumeric_MPIAIJKokkos as C's numeric routine.

   NOTE(review): mmdata (and the MatMatStruct objects it owns) is attached to C->product only at the very end, so an
   error returned by any PetscCall() between its allocation and that point leaks it. */
PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product = C->product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *mmdata;
  MatMatStruct                *mm   = NULL;
  IS                           glob = NULL;
  const PetscInt              *garray;
  PetscInt                     m, n, M, N, sz;
  ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */

  PetscFunctionBegin;
  MatCheckProduct(C, 1);
  PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  /* Local (m,n) and global (M,N) sizes of C for each supported product type */
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  mmdata           = new MatProductData_MPIAIJKokkos();
  mmdata->reusesym = product->api_user;

  if (ptype == MATPRODUCT_AB) {
    mmdata->mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
    mm = static_cast<MatMatStruct *>(mmdata->mmAB);
  } else if (ptype == MATPRODUCT_AtB) {
    mmdata->mmAtB = new MatMatStruct_AtB();
    auto atb      = mmdata->mmAtB;
    /* Merge B into one sequential matrix (kept in atb->B_local); glob holds the global column ids of its columns */
    PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
    PetscCall(ISGetIndices(glob, &garray));
    PetscCall(ISGetSize(glob, &sz));
    /* Copy the local-to-global column map to the device for the AtB helper */
    l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
    PetscCall(ISRestoreIndices(glob, &garray));
    PetscCall(ISDestroy(&glob));
    mm = static_cast<MatMatStruct *>(atb);
  } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
    mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
    mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
    auto ab       = mmdata->mmAB;
    auto atb      = mmdata->mmAtB;
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
    /* Wrap tmp's global CSR matrix in a SeqAIJKokkos Mat so it can be the right operand of the AtB helper */
    auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
    PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
    mm = static_cast<MatMatStruct *>(atb);
  }
  /* Split the C_global into petsc A, B format */
  PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
  C->product->data       = mmdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1176 
1177 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1178 {
1179   Mat_Product *product = mat->product;
1180   PetscBool    match   = PETSC_FALSE;
1181   PetscBool    usecpu  = PETSC_FALSE;
1182 
1183   PetscFunctionBegin;
1184   MatCheckProduct(mat, 1);
1185   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1186   if (match) { /* we can always fallback to the CPU if requested */
1187     switch (product->type) {
1188     case MATPRODUCT_AB:
1189       if (product->api_user) {
1190         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1191         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1192         PetscOptionsEnd();
1193       } else {
1194         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1195         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1196         PetscOptionsEnd();
1197       }
1198       break;
1199     case MATPRODUCT_AtB:
1200       if (product->api_user) {
1201         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1202         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1203         PetscOptionsEnd();
1204       } else {
1205         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1206         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1207         PetscOptionsEnd();
1208       }
1209       break;
1210     case MATPRODUCT_PtAP:
1211       if (product->api_user) {
1212         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1213         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1214         PetscOptionsEnd();
1215       } else {
1216         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1217         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1218         PetscOptionsEnd();
1219       }
1220       break;
1221     default:
1222       break;
1223     }
1224     match = (PetscBool)!usecpu;
1225   }
1226   if (match) {
1227     switch (product->type) {
1228     case MATPRODUCT_AB:
1229     case MATPRODUCT_AtB:
1230     case MATPRODUCT_PtAP:
1231       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1232       break;
1233     default:
1234       break;
1235     }
1236   }
1237   /* fallback to MPIAIJ ops */
1238   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1239   PetscFunctionReturn(PETSC_SUCCESS);
1240 }
1241 
1242 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1243 {
1244   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1245   Mat_MPIAIJKokkos *mpikok;
1246 
1247   PetscFunctionBegin;
1248   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1249   mat->preallocated = PETSC_TRUE;
1250   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1251   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1252   PetscCall(MatZeroEntries(mat));
1253   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1254   delete mpikok;
1255   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1256   PetscFunctionReturn(PETSC_SUCCESS);
1257 }
1258 
/* Insert (imode == INSERT_VALUES) or add (any other imode) the COO values v[] into a MATMPIAIJKOKKOS matrix whose
   COO pattern was set by MatSetPreallocationCOO_MPIAIJKokkos().

   v[] may be host memory, in which case it is copied to the device, or device memory, in which case it is used in
   place. Entries owned by other processes are packed into mpikok->sendbuf_d and reduced through mpiaij->coo_sf into
   mpikok->recvbuf_d. While that communication is in flight, the locally owned entries are summed into the diagonal
   block A and the off-diagonal block B. The received entries are added once the reduction completes.

   Map layout: nonzero i of A receives local entries v1(Aperm1(k)) for k in [Ajmap1(i), Ajmap1(i+1)). Received
   entries v2(Aperm2(k)) for k in [Ajmap2(i), Ajmap2(i+1)) are added at position Aimap2(i). B is handled the same way
   with the B* maps. */
static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
{
  Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
  Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
  Mat                         A = mpiaij->A, B = mpiaij->B;
  PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
  MatScalarKokkosView         Aa, Ba;
  MatScalarKokkosView         v1;
  MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
  const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
  /* NOTE: in "const T &x = a, y = b, ..." only the first declarator is a reference; the others are const copies.
     A copied Kokkos view is a shallow handle to the same data, so this does not change what the kernels access. */
  const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
  const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
  const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
  const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
  PetscMemType                memtype;

  PetscFunctionBegin;
  PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
  if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
    v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
  } else {
    v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
  }

  /* INSERT only writes the values (old ones are discarded below); ADD also needs to read the current values */
  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
    PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
    PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
  }

  /* Pack entries to be sent to remote */
  Kokkos::parallel_for(
    vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

  /* Send remote entries to their owner and overlap the communication with local computation */
  PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
  /* Add local entries to A and B in one kernel */
  /* Indices [0, Annz) address A's nonzeros and [Annz, Annz + Bnnz) address B's; with INSERT_VALUES the old value is
     replaced by the sum, otherwise the sum is added to it */
  Kokkos::parallel_for(
    Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
      PetscScalar sum = 0.0;
      if (i < Annz) {
        for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
        Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
      } else {
        i -= Annz;
        for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
        Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
      }
    });
  PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

  /* Add received remote entries to A and B in one kernel */
  /* Indices [0, Annz2) go to A at Aimap2(i), [Annz2, Annz2 + Bnnz2) go to B at Bimap2(i - Annz2); always added,
     since the local kernel above already applied the INSERT/ADD semantics */
  Kokkos::parallel_for(
    Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
      if (i < Annz2) {
        for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
      } else {
        i -= Annz2;
        for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
      }
    });

  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
    PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1332 
1333 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1334 {
1335   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1336 
1337   PetscFunctionBegin;
1338   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1339   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1340   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1341   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1342   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1343   PetscCall(MatDestroy_MPIAIJ(A));
1344   PetscFunctionReturn(PETSC_SUCCESS);
1345 }
1346 
1347 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1348 {
1349   Mat         B;
1350   Mat_MPIAIJ *a;
1351 
1352   PetscFunctionBegin;
1353   if (reuse == MAT_INITIAL_MATRIX) {
1354     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1355   } else if (reuse == MAT_REUSE_MATRIX) {
1356     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1357   }
1358   B = *newmat;
1359 
1360   B->boundtocpu = PETSC_FALSE;
1361   PetscCall(PetscFree(B->defaultvectype));
1362   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1363   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1364 
1365   a = static_cast<Mat_MPIAIJ *>(A->data);
1366   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1367   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1368   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1369 
1370   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1371   B->ops->mult                  = MatMult_MPIAIJKokkos;
1372   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1373   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1374   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1375   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1376 
1377   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1378   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1379   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1380   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1381   PetscFunctionReturn(PETSC_SUCCESS);
1382 }
1383 /*MC
1384    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1385 
   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1387 
1388    Options Database Keys:
1389 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1390 
1391   Level: beginner
1392 
1393 .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1394 M*/
1395 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1396 {
1397   PetscFunctionBegin;
1398   PetscCall(PetscKokkosInitializeCheck());
1399   PetscCall(MatCreate_MPIAIJ(A));
1400   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1401   PetscFunctionReturn(PETSC_SUCCESS);
1402 }
1403 
1404 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.
1411 
1412    Collective
1413 
1414    Input Parameters:
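+  comm  - MPI communicator
.  m     - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given)
.  n     - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given)
.  M     - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
.  N     - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
.  d_nz  - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion of the local submatrix
           (possibly different for each row) or `NULL`
.  o_nz  - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion of the local submatrix
           (possibly different for each row) or `NULL`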
1421 
1422    Output Parameter:
1423 .  A - the matrix
1424 
1425    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1426    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1427    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1428 
1429    Notes:
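   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored.

   If comm has a single MPI process, a `MATSEQAIJKOKKOS` matrix is created and o_nz, o_nnz are ignored;
   otherwise a `MATMPIAIJKOKKOS` matrix is created.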
1431 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
1433    storage.  That is, the stored row and column indices can begin at
1434    either one (as in Fortran) or zero.  See the users' manual for details.
1435 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
1438    allocation.  For large problems you MUST preallocate memory or you
1439    will get TERRIBLE performance, see the users' manual chapter on matrices.
1440 
1441    By default, this format uses inodes (identical nodes) when possible, to
1442    improve numerical efficiency of matrix-vector products and solves. We
1443    search for consecutive rows with the same nonzero structure, thereby
1444    reusing matrix information to achieve increased efficiency.
1445 
1446    Level: intermediate
1447 
.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1449 @*/
1450 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1451 {
1452   PetscMPIInt size;
1453 
1454   PetscFunctionBegin;
1455   PetscCall(MatCreate(comm, A));
1456   PetscCall(MatSetSizes(*A, m, n, M, N));
1457   PetscCallMPI(MPI_Comm_size(comm, &size));
1458   if (size > 1) {
1459     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1460     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1461   } else {
1462     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1463     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1464   }
1465   PetscFunctionReturn(PETSC_SUCCESS);
1466 }
1467 
1468 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1469 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1470 {
1471   PetscMPIInt                size, rank;
1472   MPI_Comm                   comm;
1473   PetscSplitCSRDataStructure d_mat = NULL;
1474 
1475   PetscFunctionBegin;
1476   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1477   PetscCallMPI(MPI_Comm_size(comm, &size));
1478   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1479   if (size == 1) {
1480     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1481     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1482   } else {
1483     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1484     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1485     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1486     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1487     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1488   }
1489   // act like MatSetValues because not called on host
1490   if (A->assembled) {
1491     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1492     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1493   } else {
1494     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1495   }
1496   if (!d_mat) {
1497     struct _n_SplitCSRMat h_mat; /* host container */
1498     Mat_SeqAIJKokkos     *aijkokA;
1499     Mat_SeqAIJ           *jaca;
1500     PetscInt              n = A->rmap->n, nnz;
1501     Mat                   Amat;
1502     PetscInt             *colmap;
1503 
1504     /* create and copy h_mat */
1505     h_mat.M = A->cmap->N; // use for debug build
1506     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1507     if (size == 1) {
1508       Amat            = A;
1509       jaca            = (Mat_SeqAIJ *)A->data;
1510       h_mat.rstart    = 0;
1511       h_mat.rend      = A->rmap->n;
1512       h_mat.cstart    = 0;
1513       h_mat.cend      = A->cmap->n;
1514       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1515       h_mat.offdiag.a                   = NULL;
1516       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1517     } else {
1518       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1519       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1520       PetscInt          ii;
1521       Mat_SeqAIJKokkos *aijkokB;
1522 
1523       Amat    = aij->A;
1524       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1525       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1526       jaca    = (Mat_SeqAIJ *)aij->A->data;
1527       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1528       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1529       aij->donotstash          = PETSC_TRUE;
1530       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1531       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1532       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1533       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1534       // allocate B copy data
1535       h_mat.rstart = A->rmap->rstart;
1536       h_mat.rend   = A->rmap->rend;
1537       h_mat.cstart = A->cmap->rstart;
1538       h_mat.cend   = A->cmap->rend;
1539       nnz          = jacb->i[n];
1540       if (jacb->compressedrow.use) {
1541         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1542         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1543         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1544         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1545       } else {
1546         h_mat.offdiag.i = aijkokB->i_device_data();
1547       }
1548       h_mat.offdiag.j = aijkokB->j_device_data();
1549       h_mat.offdiag.a = aijkokB->a_device_data();
1550       {
1551         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1552         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1553         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1554         h_mat.colmap = aijkokB->colmap_d.data();
1555         PetscCall(PetscFree(colmap));
1556       }
1557       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1558       h_mat.offdiag.n                 = n;
1559     }
1560     // allocate A copy data
1561     nnz                          = jaca->i[n];
1562     h_mat.diag.n                 = n;
1563     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1564     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1565     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
1566     h_mat.diag.i = aijkokA->i_device_data();
1567     h_mat.diag.j = aijkokA->j_device_data();
1568     h_mat.diag.a = aijkokA->a_device_data();
1569     // copy pointers and metadata to device
1570     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1571     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1572     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1573   }
1574   *B           = d_mat;       // return it, set it in Mat, and set it up
1575   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1576   PetscFunctionReturn(PETSC_SUCCESS);
1577 }
1578 
1579 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1580 {
1581   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1582 
1583   PetscFunctionBegin;
1584   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1585   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1586   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1587   else *mask = "PETSC_OFFLOAD_BOTH";
1588   PetscFunctionReturn(PETSC_SUCCESS);
1589 }
1590 
1591 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1592 {
1593   PetscMPIInt size;
1594   Mat         Ad, Ao;
1595   const char *amask, *bmask;
1596 
1597   PetscFunctionBegin;
1598   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1599 
1600   if (size == 1) {
1601     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1602     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1603   } else {
1604     Ad = ((Mat_MPIAIJ *)A->data)->A;
1605     Ao = ((Mat_MPIAIJ *)A->data)->B;
1606     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1607     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1608     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1609   }
1610   PetscFunctionReturn(PETSC_SUCCESS);
1611 }
1612