xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 2d30e087755efd99e28fdfe792ffbeb2ee1ea928)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) {
9   Mat_SeqAIJKokkos *aijkok;
10 
11   PetscFunctionBegin;
12   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
13   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
14   if (aijkok && aijkok->device_mat_d.data()) {
15     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
16   }
17 
18   PetscFunctionReturn(0);
19 }
20 
21 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) {
22   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
23 
24   PetscFunctionBegin;
25   PetscCall(PetscLayoutSetUp(mat->rmap));
26   PetscCall(PetscLayoutSetUp(mat->cmap));
27 #if defined(PETSC_USE_DEBUG)
28   if (d_nnz) {
29     PetscInt i;
30     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
31   }
32   if (o_nnz) {
33     PetscInt i;
34     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
35   }
36 #endif
37 #if defined(PETSC_USE_CTABLE)
38   PetscCall(PetscTableDestroy(&mpiaij->colmap));
39 #else
40   PetscCall(PetscFree(mpiaij->colmap));
41 #endif
42   PetscCall(PetscFree(mpiaij->garray));
43   PetscCall(VecDestroy(&mpiaij->lvec));
44   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
45   /* Because the B will have been resized we simply destroy it and create a new one each time */
46   PetscCall(MatDestroy(&mpiaij->B));
47 
48   if (!mpiaij->A) {
49     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
50     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
51     PetscCall(PetscLogObjectParent((PetscObject)mat, (PetscObject)mpiaij->A));
52   }
53   if (!mpiaij->B) {
54     PetscMPIInt size;
55     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
56     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
57     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
58     PetscCall(PetscLogObjectParent((PetscObject)mat, (PetscObject)mpiaij->B));
59   }
60   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
61   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
62   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
63   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
64   mat->preallocated = PETSC_TRUE;
65   PetscFunctionReturn(0);
66 }
67 
68 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
69   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
70   PetscInt    nt;
71 
72   PetscFunctionBegin;
73   PetscCall(VecGetLocalSize(xx, &nt));
74   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
75   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
76   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
77   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
78   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
79   PetscFunctionReturn(0);
80 }
81 
82 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) {
83   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
84   PetscInt    nt;
85 
86   PetscFunctionBegin;
87   PetscCall(VecGetLocalSize(xx, &nt));
88   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
89   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
90   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
91   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
92   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
93   PetscFunctionReturn(0);
94 }
95 
96 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
97   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
98   PetscInt    nt;
99 
100   PetscFunctionBegin;
101   PetscCall(VecGetLocalSize(xx, &nt));
102   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
103   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
104   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
105   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
106   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
107   PetscFunctionReturn(0);
108 }
109 
110 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
111    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
112    C still uses local column ids. Their corresponding global column ids are returned in glob.
113 */
114 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) {
115   Mat             Ad, Ao;
116   const PetscInt *cmap;
117 
118   PetscFunctionBegin;
119   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
120   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
121   if (glob) {
122     PetscInt cst, i, dn, on, *gidx;
123     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
124     PetscCall(MatGetLocalSize(Ao, NULL, &on));
125     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
126     PetscCall(PetscMalloc1(dn + on, &gidx));
127     for (i = 0; i < dn; i++) gidx[i] = cst + i;
128     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
129     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
130   }
131   PetscFunctionReturn(0);
132 }
133 
134 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
135 struct MatMatStruct {
136   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
137   PetscSF             sf;      /* SF to send/recv matrix entries */
138   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
139   Mat                 C1, C2, B_local;
140   KokkosCsrMatrix     C1_global, C2_global, C_global;
141   KernelHandle        kh;
142   MatMatStruct() {
143     C1 = C2 = B_local = NULL;
144     sf                = NULL;
145   }
146 
147   ~MatMatStruct() {
148     MatDestroy(&C1);
149     MatDestroy(&C2);
150     MatDestroy(&B_local);
151     PetscSFDestroy(&sf);
152     kh.destroy_spadd_handle();
153   }
154 };
155 
156 struct MatMatStruct_AB : public MatMatStruct {
157   MatColIdxKokkosView rows;
158   MatRowMapKokkosView rowoffset;
159   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
160 
161   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
162   ~MatMatStruct_AB() {
163     MatDestroy(&B_other);
164     MatDestroy(&C_petsc);
165   }
166 };
167 
168 struct MatMatStruct_AtB : public MatMatStruct {
169   MatRowMapKokkosView srcrowoffset, dstrowoffset;
170 };
171 
172 struct MatProductData_MPIAIJKokkos {
173   MatMatStruct_AB  *mmAB;
174   MatMatStruct_AtB *mmAtB;
175   PetscBool         reusesym;
176 
177   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
178   ~MatProductData_MPIAIJKokkos() {
179     delete mmAB;
180     delete mmAtB;
181   }
182 };
183 
184 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) {
185   PetscFunctionBegin;
186   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
187   PetscFunctionReturn(0);
188 }
189 
190 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
191 
192    Input Parameters:
193 +  A       - the MATSEQAIJKOKKOS matrix
194 .  N       - new column size for the returned Kokkos matrix
195 -  l2g     - a map that maps old col ids to new col ids
196 
197    Output Parameters:
198 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
199  */
200 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat) {
201   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
202   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
203 
204   PetscFunctionBegin;
205   PetscCallCXX(Kokkos::parallel_for(
206     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
207   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
208   PetscFunctionReturn(0);
209 }
210 
211 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
212    It is similar to MatCreateMPIAIJWithSplitArrays.
213 
214   Input Parameters:
215 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
216 .  A     - the diag matrix using local col ids
217 -  B     - the offdiag matrix using global col ids
218 
219   Output Parameters:
220 .  mat   - the updated MATMPIAIJKOKKOS matrix
221 */
222 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B) {
223   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
224   PetscInt          m, n, M, N, Am, An, Bm, Bn;
225   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
226 
227   PetscFunctionBegin;
228   PetscCall(MatGetSize(mat, &M, &N));
229   PetscCall(MatGetLocalSize(mat, &m, &n));
230   PetscCall(MatGetLocalSize(A, &Am, &An));
231   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
232 
233   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
234   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
235   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
236   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
237   mpiaij->A = A;
238   mpiaij->B = B;
239 
240   mat->preallocated     = PETSC_TRUE;
241   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
242 
243   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
244   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
245   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
246     also gets mpiaij->B compacted, with its col ids and size reduced
247   */
248   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
249   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
250   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
251 
252   /* Update bkok with new local col ids (stored on host) and size */
253   bkok->j_dual.modify_host();
254   bkok->j_dual.sync_device();
255   bkok->SetColSize(mpiaij->B->cmap->n);
256   PetscFunctionReturn(0);
257 }
258 
259 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
260 
261    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
262    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
263    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
264    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
265 
266    Collective on comm of ownerSF
267 
268    Input Parameters:
269 +   B       - the SEQAIJKOKKOS matrix, using local col ids
270 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
271 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
272 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
273 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
274 
275    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
276 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
277 .   abuf      - buffer for sending matrix values
278 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
279                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
280 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
281 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
282 */
283 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C) {
284   Mat_SeqAIJKokkos *bkok, *ckok;
285 
286   PetscFunctionBegin;
287   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
288   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
289 
290   if (reuse == MAT_REUSE_MATRIX) {
291     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
292 
293     const auto &Ba = bkok->a_dual.view_device();
294     const auto &Bi = bkok->i_dual.view_device();
295     const auto &Ca = ckok->a_dual.view_device();
296 
297     /* Copy Ba to abuf */
298     Kokkos::parallel_for(
299       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
300         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
301         PetscInt r    = rows(i);
302         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
303         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
304       });
305 
306     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
307     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
308     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
309     ckok->a_dual.modify_device();
310   } else if (reuse == MAT_INITIAL_MATRIX) {
311     MPI_Comm    comm;
312     PetscMPIInt tag;
313     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
314 
315     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
316     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
317     Cm = nleaves; /* row size of C */
318     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
319 
320     /* Get row lens (nz) of B's rows for later fast query */
321     PetscInt       *Browlens;
322     const PetscInt *tmp = bkok->i_host_data();
323     PetscCall(PetscMalloc1(nroots, &Browlens));
324     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
325 
326     /* By ownerSF, each proc gets lens of rows of C */
327     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
328     Ci_h    = Ci.view_host().data();
329     Ci_h[0] = 0;
330     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
331     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
332     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
333     Cnnz = Ci_h[Cm];
334     Ci.modify_host();
335     Ci.sync_device();
336 
337     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
338     MatColIdxKokkosDualView Cj("j", Cnnz);
339     MatScalarKokkosDualView Ca("a", Cnnz);
340 
341     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
342     const PetscMPIInt *iranks, *ranks;
343     const PetscInt    *ioffset, *irootloc, *roffset;
344     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
345     MPI_Request       *reqs;
346 
347     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
348     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
349 
350     /* figure out offsets at the send buffer, to build the SF
351       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
352       rowptr[] - stores offsets for data of each row in abuf
353 
354       rdisp[]  - to receive sdisp[]
355     */
356     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
357     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
358     rowptr = rowptr_h.data();
359 
360     sdisp[0]  = 0;
361     rowptr[0] = 0;
362     for (i = 0; i < niranks; i++) { /* for each receiver */
363       PetscInt len, nz = 0;
364       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
365         len           = Browlens[irootloc[j]];
366         rowptr[j + 1] = rowptr[j] + len;
367         nz += len;
368       }
369       sdisp[i + 1] = sdisp[i] + nz;
370     }
371     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
372     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
373     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
374     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
375 
376     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
377     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
378     PetscSFNode *iremote;
379     PetscCall(PetscMalloc1(nleaves2, &iremote));
380     for (i = 0; i < nranks; i++) { /* for each sender */
381       k = 0;
382       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
383         iremote[j].rank  = ranks[i];
384         iremote[j].index = rdisp[i] + k;
385         k++;
386       }
387     }
388     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
389     PetscCall(PetscSFCreate(comm, &bcastSF));
390     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
391 
392     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
393       from local to global. Then use bcastSF to fill Ca, Cj.
394     */
395     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
396     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
397     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
398 
399     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
400 
401     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
402     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
403 
404     const auto &Ba = bkok->a_dual.view_device();
405     const auto &Bi = bkok->i_dual.view_device();
406     const auto &Bj = bkok->j_dual.view_device();
407 
408     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
409     Kokkos::parallel_for(
410       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
411         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
412         PetscInt r    = rows(i);
413         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
414         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
415           abuf(base + k) = Ba(Bi(r) + k);
416           jbuf(base + k) = l2g(Bj(Bi(r) + k));
417         });
418       });
419 
420     /* Send abuf & jbuf to fill Ca, Cj */
421     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
422     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
423     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
424     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
425     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
426     Cj.sync_host();
427     Ca.modify_device();
428 
429     /* Construct C with Ca, Ci, Cj */
430     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
431     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
432     PetscCall(PetscFree3(sdisp, rdisp, reqs));
433     PetscCall(PetscFree(Browlens));
434   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
435   PetscFunctionReturn(0);
436 }
437 
438 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
439 
440   It is the reverse of MatSeqAIJKokkosBcast in some sense.
441 
442   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
443   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
444   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
445 
446   Input Parameters:
447 +  A        - the SEQAIJKOKKOS matrix to be reduced
448 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
449 .  local    - true if A uses local col ids; false if A is already in global col ids.
450 .  N        - if local, N is A's global col size
451 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
452 -  ownerSF  - the SF specifies ownership (root) of rows in A
453 
454   Output Parameters:
455 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
456 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
457 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
458 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
459                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
460 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
461 
462    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
463  */
464 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C) {
465   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
466   const PetscInt   *Ai;
467   Mat_SeqAIJKokkos *akok;
468 
469   PetscFunctionBegin;
470   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
471   PetscCall(MatGetSize(A, &Am, &An));
472   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
473   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
474   Annz = Ai[Am];
475 
476   if (reuse == MAT_REUSE_MATRIX) {
477     /* Send Aa to abuf */
478     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
479     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
480 
481     /* Copy abuf to Ca */
482     const MatScalarKokkosView &Ca = C.values;
483     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
484     Kokkos::parallel_for(
485       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
486         PetscInt i   = t.league_rank();
487         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
488         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
489         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
490       });
491   } else if (reuse == MAT_INITIAL_MATRIX) {
492     MPI_Comm     comm;
493     MPI_Request *reqs;
494     PetscMPIInt  tag;
495     PetscInt     Cm;
496 
497     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
498     PetscCall(PetscCommGetNewTag(comm, &tag));
499 
500     PetscInt           niranks, nranks, nroots, nleaves;
501     const PetscMPIInt *iranks, *ranks;
502     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
503     PetscCall(PetscSFSetUp(ownerSF));
504     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
505     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
506     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
507     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
508     Cm    = nroots;
509     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
510 
511     /* Tell owners how long each row I will send */
512     PetscInt               *srowlens;                              /* send buf of row lens */
513     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
514     PetscInt               *rrowlens = rrowlens_h.data();
515 
516     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
517     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
518     rrowlens[0] = 0;
519     rrowlens++; /* shift the pointer to make the following expression more readable */
520     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
521     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
522     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
523 
524     /* Owner builds Ci on host by histogramming rrowlens[] */
525     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
526     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
527     MatRowMapType *Ci_ptr = Ci_h.data();
528 
529     for (i = 0; i < nrows; i++) {
530       r = rows[i]; /* local row id of i-th received row */
531 #if defined(PETSC_USE_DEBUG)
532       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
533 #endif
534       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
535     }
536     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
537     Cnnz = Ci_ptr[Cm];
538 
539     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
540     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
541     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
542     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
543 
544     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
545     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
546       r                    = rows[i];                   /* row id in C */
547       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
548       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
549     }
550     PetscCall(PetscFree(currowlens));
551 
552     rrowlens--;
553     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
554     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
555     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
556 
557     /* Build the reduceSF, which performs buffer to buffer send/recv */
558     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
559     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
560     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
561     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
562     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
563     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
564 
565     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
566     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
567     PetscSFNode *iremote;
568     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
569     for (i = 0; i < nranks; i++) {
570       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
571       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
572       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
573       for (PetscInt k = 0; k < nz; k++) {
574         iremote[leafbase + k].rank  = ranks[i];
575         iremote[leafbase + k].index = rootbase + k;
576       }
577     }
578     PetscCall(PetscSFCreate(comm, &reduceSF));
579     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
580     PetscCall(PetscFree2(sdisp, rdisp));
581 
582     /* Reduce Aa, Ajg to abuf and jbuf */
583 
584     /* If A uses local col ids, convert them to global ones before sending */
585     MatColIdxKokkosView Ajg;
586     if (local) {
587       Ajg                           = MatColIdxKokkosView("j", Annz);
588       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
589       Kokkos::parallel_for(
590         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
591     } else {
592       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
593     }
594 
595     MatColIdxKokkosView jbuf("jbuf", Cnnz);
596     abuf = MatScalarKokkosView("abuf", Cnnz);
597     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
598     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
599     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
600     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
601 
602     /* Copy data from abuf, jbuf to Ca, Cj */
603     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
604     MatColIdxKokkosView Cj("j", Cnnz);
605     MatScalarKokkosView Ca("a", Cnnz);
606 
607     Kokkos::parallel_for(
608       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
609         PetscInt i   = t.league_rank();
610         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
611         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
612         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
613           Ca(dst + k) = abuf(src + k);
614           Cj(dst + k) = jbuf(src + k);
615         });
616       });
617 
618     /* Build C with Ca, Ci, Cj */
619     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
620     PetscCall(PetscFree2(srowlens, reqs));
621   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
622   PetscFunctionReturn(0);
623 }
624 
625 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
626 
627   Input Parameters:
628 +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
629 .  reuse    - indicate whether the matrix has called this function before
630 .  csrmat   - the KokkosCsrMatrix, of size m,N
631 -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
632               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
633               entry is 5, then Cdstart[i] = 3.
634 
635   Output Parameters:
636 +  C        - the updated `MATMPIAIJKOKKOS` matrix
637 -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
638 
639   Note:
640    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
641 
642 .seealso: `MATMPIAIJKOKKOS`
643  */
644 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart) {
645   const MatScalarKokkosView      &Ca = csrmat.values;
646   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
647   PetscInt                        m, n, N;
648 
649   PetscFunctionBegin;
650   PetscCall(MatGetLocalSize(C, &m, &n));
651   PetscCall(MatGetSize(C, NULL, &N));
652 
653   if (reuse == MAT_REUSE_MATRIX) {
654     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
655     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
656     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
657     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
658     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
659 
660     /* Fill 'a' of Cd and Co on device */
661     Kokkos::parallel_for(
662       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
663         PetscInt i       = t.league_rank();     /* row i */
664         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
665         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
666         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
667         PetscInt cdend   = cdstart + cdlen;
668         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
669         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
670           if (k < cdstart) { /* k in [0, cdstart) */
671             Coa(Coi(i) + k) = Ca(Ci(i) + k);
672           } else if (k < cdend) { /* k in [cdstart, cdend) */
673             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
674           } else { /* k in [cdend, clen) */
675             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
676           }
677         });
678       });
679 
680     akok->a_dual.modify_device();
681     bkok->a_dual.modify_device();
682   } else if (reuse == MAT_INITIAL_MATRIX) {
683     Mat                        Cd, Co;
684     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
685     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
686     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
687     PetscInt                   cstart, cend;
688 
689     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
690        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
691        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
692        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
693        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
694      */
695     Cdstart = MatRowMapKokkosView("Cdstart", m);
696     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
697 
698     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
699       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
700      */
701     Kokkos::parallel_for(
702       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
703         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
704                                                    PetscInt i = t.league_rank(); /* row i */
705                                                    PetscInt j, first, count, step;
706 
707                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
708                                                      Cdi(0) = 0;
709                                                      Coi(0) = 0;
710                                                    }
711 
712                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
713           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
714         */
715                                                    count = Ci(i + 1) - Ci(i);
716                                                    first = Ci(i);
717                                                    while (count > 0) {
718                                                      j    = first;
719                                                      step = count / 2;
720                                                      j += step;
721                                                      if (Cj(j) < cstart) {
722                                                        first = ++j;
723                                                        count -= step + 1;
724                                                      } else count = step;
725                                                    }
726                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
727 
728                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
729                                                    count = Ci(i + 1) - first;
730                                                    while (count > 0) {
731                                                      j    = first;
732                                                      step = count / 2;
733                                                      j += step;
734                                                      if (Cj(j) < cend) {
735                                                        first = ++j;
736                                                        count -= step + 1;
737                                                      } else count = step;
738                                                    }
739                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
740                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
741         });
742       });
743 
744     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
745     Kokkos::parallel_scan(
746       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
747         update += Cdi(i);
748         if (final) Cdi(i) = update;
749       });
750     Kokkos::parallel_scan(
751       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
752         update += Coi(i);
753         if (final) Coi(i) = update;
754       });
755 
756     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
757        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
758     */
759     Cdi_dual.modify_device();
760     Coi_dual.modify_device();
761     Cdi_dual.sync_host();
762     Coi_dual.sync_host();
763     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
764     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
765 
766     /* With nnz, allocate a, j for Cd and Co */
767     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
768     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
769 
770     /* Fill a, j of Cd and Co on device */
771     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
772     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
773 
774     Kokkos::parallel_for(
775       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
776         PetscInt i       = t.league_rank();     /* row i */
777         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
778         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
779         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
780         PetscInt cdend   = cdstart + cdlen;
781         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
782         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
783           if (k < cdstart) { /* k in [0, cdstart) */
784             Coa(Coi(i) + k) = Ca(Ci(i) + k);
785             Coj(Coi(i) + k) = Cj(Ci(i) + k);
786           } else if (k < cdend) { /* k in [cdstart, cdend) */
787             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
788             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
789           } else {                                                /* k in [cdend, clen) */
790             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
791             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
792           }
793         });
794       });
795 
796     Cdj_dual.modify_device();
797     Cda_dual.modify_device();
798     Coj_dual.modify_device();
799     Coa_dual.modify_device();
800     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
801     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
802     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
803     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
804     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
805     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
806   }
807   PetscFunctionReturn(0);
808 }
809 
810 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
811 
812   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
813 
814   Input Parameters:
815 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
816 .  reuse    - indicate whether the matrix has called this function before
817 .  csrmat   - the KokkosCsrMatrix, of size m,N
818 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
819               entry of the diag block of C in csrmat's j array.
820 
821   Output Parameters:
822 +  C        - the updated MATMPIAIJKOKKOS matrix
823 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
824 
825   Note:
826   the input matrix's col ids and col size will be changed.
827 */
828 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g) {
829   Mat_SeqAIJKokkos      *ckok;
830   ISLocalToGlobalMapping l2gmap;
831   const PetscInt        *garray;
832   PetscInt               sz;
833 
834   PetscFunctionBegin;
835   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
836   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
837   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
838   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
839   ckok->j_dual.sync_device();
840   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
841 
842   /* Build l2g -- the local to global mapping of C's cols */
843   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
844   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
845   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
846 
847   ConstMatColIdxKokkosViewHost tmp(garray, sz);
848   l2g = MatColIdxKokkosView("l2g", sz);
849   Kokkos::deep_copy(l2g, tmp);
850 
851   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
852   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
853   PetscFunctionReturn(0);
854 }
855 
856 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
857 
858   Input Parameters:
859 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
860 .  A        - an MPIAIJKOKKOS matrix
861 .  B        - an MPIAIJKOKKOS matrix
862 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
863 
864   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
865 */
866 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) {
867   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
868   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
869   IS                       glob = NULL;
870   const PetscInt          *garray;
871   PetscInt                 N = B->cmap->N, sz;
872   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
873   MatColIdxKokkosView      l2g2;
874   Mat                      C1, C2; /* intermediate matrices */
875 
876   PetscFunctionBegin;
877   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
878   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
879   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
880   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
881   PetscCall(MatProductSetFill(C1, product->fill));
882   C1->product->api_user = product->api_user;
883   PetscCall(MatProductSetFromOptions(C1));
884   PetscUseTypeMethod(C1, productsymbolic);
885 
886   PetscCall(ISGetIndices(glob, &garray));
887   PetscCall(ISGetSize(glob, &sz));
888   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
889   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
890   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
891 
892   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
893   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
894 
895   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
896   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
897   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
898   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
899   PetscCall(MatProductSetFill(C2, product->fill));
900   C2->product->api_user = product->api_user;
901   PetscCall(MatProductSetFromOptions(C2));
902   PetscUseTypeMethod(C2, productsymbolic);
903   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
904 
905   /* C = C1 + C2.  We actually use their global col ids versions in adding */
906   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
907   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
908   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
909   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
910 
911   mm->C1 = C1;
912   mm->C2 = C2;
913   PetscCall(ISRestoreIndices(glob, &garray));
914   PetscCall(ISDestroy(&glob));
915   PetscFunctionReturn(0);
916 }
917 
918 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
919 
920   Input Parameters:
921 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
922 .  A        - an MPIAIJKOKKOS matrix
923 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
924 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
925 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
926 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
927 -  mm       - a struct used to stash intermediate data in AtB
928 
929   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
930 */
931 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm) {
932   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
933   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
934   Mat         C1, C2;               /* intermediate matrices */
935 
936   PetscFunctionBegin;
937   /* C1 = Ad^t * B */
938   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
939   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
940   PetscCall(MatProductSetFill(C1, product->fill));
941   C1->product->api_user = product->api_user;
942   PetscCall(MatProductSetFromOptions(C1));
943   PetscUseTypeMethod(C1, productsymbolic);
944 
945   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
946   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
947 
948   /* C2 = Ao^t * B */
949   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
950   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
951   PetscCall(MatProductSetFill(C2, product->fill));
952   C2->product->api_user = product->api_user;
953   PetscCall(MatProductSetFromOptions(C2));
954   PetscUseTypeMethod(C2, productsymbolic);
955 
956   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
957 
958   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
959   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
960   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
961   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
962   mm->C1 = C1;
963   mm->C2 = C2;
964   PetscFunctionReturn(0);
965 }
966 
967 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) {
968   Mat_Product                 *product = C->product;
969   MatProductType               ptype;
970   MatProductData_MPIAIJKokkos *mmdata;
971   MatMatStruct                *mm = NULL;
972   MatMatStruct_AB             *ab;
973   MatMatStruct_AtB            *atb;
974   Mat                          A, B, Ad, Ao, Bd, Bo;
975   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
976 
977   PetscFunctionBegin;
978   MatCheckProduct(C, 1);
979   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
980   ptype  = product->type;
981   A      = product->A;
982   B      = product->B;
983   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
984   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
985   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
986   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
987 
988   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
989     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
990     ab               = mmdata->mmAB;
991     atb              = mmdata->mmAtB;
992     if (ab) {
993       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
994       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
995     }
996     if (atb) {
997       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
998       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
999     }
1000     PetscFunctionReturn(0);
1001   }
1002 
1003   if (ptype == MATPRODUCT_AB) {
1004     ab = mmdata->mmAB;
1005     /* C1 = Ad * B_local */
1006     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1007     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1008     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1009     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1010     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1011     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1012     /* C2 = Ao * B_other */
1013     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1014     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1015     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1016     /* C = C1_global + C2_global */
1017     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1018     mm = static_cast<MatMatStruct *>(ab);
1019   } else if (ptype == MATPRODUCT_AtB) {
1020     atb = mmdata->mmAtB;
1021     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1022     /* C1 = Ad^t * B_local */
1023     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1024     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1025     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1026     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1027 
1028     /* C2 = Ao^t * B_local */
1029     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1030     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1031     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1032     /* Form C2_global */
1033     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1034     /* C = C1_global + C2_global */
1035     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1036     mm = static_cast<MatMatStruct *>(atb);
1037   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1038     ab = mmdata->mmAB;
1039     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1040 
1041     /* ab->C1 = Ad * B_local */
1042     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1043     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1044     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1045     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1046     /* ab->C2 = Ao * B_other */
1047     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1048     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1049     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1050 
1051     /* atb->C1 = Bd^t * ab->C_petsc */
1052     atb = mmdata->mmAtB;
1053     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1054     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1055     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1056     /* atb->C2 = Bo^t * ab->C_petsc */
1057     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1058     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1059     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1060     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1061     mm = static_cast<MatMatStruct *>(atb);
1062   }
1063   /* Split C_global to form C */
1064   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1065   PetscFunctionReturn(0);
1066 }
1067 
1068 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) {
1069   Mat                          A, B;
1070   Mat_Product                 *product = C->product;
1071   MatProductType               ptype;
1072   MatProductData_MPIAIJKokkos *mmdata;
1073   MatMatStruct                *mm   = NULL;
1074   IS                           glob = NULL;
1075   const PetscInt              *garray;
1076   PetscInt                     m, n, M, N, sz;
1077   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1078 
1079   PetscFunctionBegin;
1080   MatCheckProduct(C, 1);
1081   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1082   ptype = product->type;
1083   A     = product->A;
1084   B     = product->B;
1085 
1086   switch (ptype) {
1087   case MATPRODUCT_AB:
1088     m = A->rmap->n;
1089     n = B->cmap->n;
1090     M = A->rmap->N;
1091     N = B->cmap->N;
1092     break;
1093   case MATPRODUCT_AtB:
1094     m = A->cmap->n;
1095     n = B->cmap->n;
1096     M = A->cmap->N;
1097     N = B->cmap->N;
1098     break;
1099   case MATPRODUCT_PtAP:
1100     m = B->cmap->n;
1101     n = B->cmap->n;
1102     M = B->cmap->N;
1103     N = B->cmap->N;
1104     break; /* BtAB */
1105   default: SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1106   }
1107 
1108   PetscCall(MatSetSizes(C, m, n, M, N));
1109   PetscCall(PetscLayoutSetUp(C->rmap));
1110   PetscCall(PetscLayoutSetUp(C->cmap));
1111   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1112 
1113   mmdata           = new MatProductData_MPIAIJKokkos();
1114   mmdata->reusesym = product->api_user;
1115 
1116   if (ptype == MATPRODUCT_AB) {
1117     mmdata->mmAB = new MatMatStruct_AB();
1118     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1119     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1120   } else if (ptype == MATPRODUCT_AtB) {
1121     mmdata->mmAtB = new MatMatStruct_AtB();
1122     auto atb      = mmdata->mmAtB;
1123     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1124     PetscCall(ISGetIndices(glob, &garray));
1125     PetscCall(ISGetSize(glob, &sz));
1126     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
1127     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1128     PetscCall(ISRestoreIndices(glob, &garray));
1129     PetscCall(ISDestroy(&glob));
1130     mm = static_cast<MatMatStruct *>(atb);
1131   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1132     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1133     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1134     auto ab       = mmdata->mmAB;
1135     auto atb      = mmdata->mmAtB;
1136     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1137     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1138     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1139     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1140     mm = static_cast<MatMatStruct *>(atb);
1141   }
1142   /* Split the C_global into petsc A, B format */
1143   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1144   C->product->data       = mmdata;
1145   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1146   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1147   PetscFunctionReturn(0);
1148 }
1149 
1150 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) {
1151   Mat_Product *product = mat->product;
1152   PetscBool    match   = PETSC_FALSE;
1153   PetscBool    usecpu  = PETSC_FALSE;
1154 
1155   PetscFunctionBegin;
1156   MatCheckProduct(mat, 1);
1157   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1158   if (match) { /* we can always fallback to the CPU if requested */
1159     switch (product->type) {
1160     case MATPRODUCT_AB:
1161       if (product->api_user) {
1162         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1163         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1164         PetscOptionsEnd();
1165       } else {
1166         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1167         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1168         PetscOptionsEnd();
1169       }
1170       break;
1171     case MATPRODUCT_AtB:
1172       if (product->api_user) {
1173         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1174         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1175         PetscOptionsEnd();
1176       } else {
1177         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1178         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1179         PetscOptionsEnd();
1180       }
1181       break;
1182     case MATPRODUCT_PtAP:
1183       if (product->api_user) {
1184         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1185         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1186         PetscOptionsEnd();
1187       } else {
1188         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1189         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1190         PetscOptionsEnd();
1191       }
1192       break;
1193     default: break;
1194     }
1195     match = (PetscBool)!usecpu;
1196   }
1197   if (match) {
1198     switch (product->type) {
1199     case MATPRODUCT_AB:
1200     case MATPRODUCT_AtB:
1201     case MATPRODUCT_PtAP: mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; break;
1202     default: break;
1203     }
1204   }
1205   /* fallback to MPIAIJ ops */
1206   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1207   PetscFunctionReturn(0);
1208 }
1209 
1210 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) {
1211   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1212   Mat_MPIAIJKokkos *mpikok;
1213 
1214   PetscFunctionBegin;
1215   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
1216   mat->preallocated = PETSC_TRUE;
1217   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1218   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1219   PetscCall(MatZeroEntries(mat));
1220   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1221   delete mpikok;
1222   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1223   PetscFunctionReturn(0);
1224 }
1225 
1226 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) {
1227   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1228   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1229   Mat                         A = mpiaij->A, B = mpiaij->B;
1230   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1231   MatScalarKokkosView         Aa, Ba;
1232   MatScalarKokkosView         v1;
1233   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1234   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1235   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1236   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1237   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1238   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1239   PetscMemType                memtype;
1240 
1241   PetscFunctionBegin;
1242   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1243   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1244     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1245   } else {
1246     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1247   }
1248 
1249   if (imode == INSERT_VALUES) {
1250     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1251     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1252   } else {
1253     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1254     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1255   }
1256 
1257   /* Pack entries to be sent to remote */
1258   Kokkos::parallel_for(
1259     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1260 
1261   /* Send remote entries to their owner and overlap the communication with local computation */
1262   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1263   /* Add local entries to A and B in one kernel */
1264   Kokkos::parallel_for(
1265     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1266       PetscScalar sum = 0.0;
1267       if (i < Annz) {
1268         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1269         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1270       } else {
1271         i -= Annz;
1272         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1273         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1274       }
1275     });
1276   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1277 
1278   /* Add received remote entries to A and B in one kernel */
1279   Kokkos::parallel_for(
1280     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1281       if (i < Annz2) {
1282         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1283       } else {
1284         i -= Annz2;
1285         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1286       }
1287     });
1288 
1289   if (imode == INSERT_VALUES) {
1290     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1291     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1292   } else {
1293     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1294     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1295   }
1296   PetscFunctionReturn(0);
1297 }
1298 
1299 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) {
1300   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1301 
1302   PetscFunctionBegin;
1303   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1304   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1305   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1306   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1307   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1308   PetscCall(MatDestroy_MPIAIJ(A));
1309   PetscFunctionReturn(0);
1310 }
1311 
1312 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) {
1313   Mat         B;
1314   Mat_MPIAIJ *a;
1315 
1316   PetscFunctionBegin;
1317   if (reuse == MAT_INITIAL_MATRIX) {
1318     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1319   } else if (reuse == MAT_REUSE_MATRIX) {
1320     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1321   }
1322   B = *newmat;
1323 
1324   B->boundtocpu = PETSC_FALSE;
1325   PetscCall(PetscFree(B->defaultvectype));
1326   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1327   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1328 
1329   a = static_cast<Mat_MPIAIJ *>(A->data);
1330   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1331   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1332   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1333 
1334   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1335   B->ops->mult                  = MatMult_MPIAIJKokkos;
1336   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1337   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1338   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1339   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1340 
1341   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1342   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1343   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1344   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1345   PetscFunctionReturn(0);
1346 }
1347 /*MC
1348    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1349 
   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
1351 
1352    Options Database Keys:
1353 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1354 
1355   Level: beginner
1356 
1357 .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1358 M*/
1359 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) {
1360   PetscFunctionBegin;
1361   PetscCall(PetscKokkosInitializeCheck());
1362   PetscCall(MatCreate_MPIAIJ(A));
1363   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1364   PetscFunctionReturn(0);
1365 }
1366 
1367 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
1370    to Kokkos for calculations. For good matrix
1371    assembly performance the user should preallocate the matrix storage by setting
1372    the parameter nz (or the array nnz).  By setting these parameters accurately,
1373    performance during matrix assembly can be increased by more than a factor of 50.
1374 
1375    Collective
1376 
1377    Input Parameters:
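+  comm - MPI communicator
.  m - number of local rows (or `PETSC_DECIDE` to have it calculated if M is given)
.  n - number of local columns (or `PETSC_DECIDE` to have it calculated if N is given)
.  M - number of global rows (or `PETSC_DETERMINE` to have it calculated if m is given)
.  N - number of global columns (or `PETSC_DETERMINE` to have it calculated if n is given)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows);
          not used when comm has a single process
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           (possibly different for each row) or NULL; not used when comm has a single process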
1384 
1385    Output Parameter:
1386 .  A - the matrix
1387 
1388    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1389    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1390    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1391 
1392    Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored
1394 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
1396    storage.  That is, the stored row and column indices can begin at
1397    either one (as in Fortran) or zero.  See the users' manual for details.
1398 
   Specify the preallocated storage with either (d_nz, o_nz) or (d_nnz, o_nnz) (not both).
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
1401    allocation.  For large problems you MUST preallocate memory or you
1402    will get TERRIBLE performance, see the users' manual chapter on matrices.
1403 
1404    By default, this format uses inodes (identical nodes) when possible, to
1405    improve numerical efficiency of matrix-vector products and solves. We
1406    search for consecutive rows with the same nonzero structure, thereby
1407    reusing matrix information to achieve increased efficiency.
1408 
1409    Level: intermediate
1410 
.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1412 @*/
1413 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) {
1414   PetscMPIInt size;
1415 
1416   PetscFunctionBegin;
1417   PetscCall(MatCreate(comm, A));
1418   PetscCall(MatSetSizes(*A, m, n, M, N));
1419   PetscCallMPI(MPI_Comm_size(comm, &size));
1420   if (size > 1) {
1421     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1422     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1423   } else {
1424     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1425     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1426   }
1427   PetscFunctionReturn(0);
1428 }
1429 
1430 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1431 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B) {
1432   PetscMPIInt                size, rank;
1433   MPI_Comm                   comm;
1434   PetscSplitCSRDataStructure d_mat = NULL;
1435 
1436   PetscFunctionBegin;
1437   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1438   PetscCallMPI(MPI_Comm_size(comm, &size));
1439   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1440   if (size == 1) {
1441     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1442     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1443   } else {
1444     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1445     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1446     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1447     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1448     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1449   }
1450   // act like MatSetValues because not called on host
1451   if (A->assembled) {
1452     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1453     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1454   } else {
1455     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1456   }
1457   if (!d_mat) {
1458     struct _n_SplitCSRMat h_mat; /* host container */
1459     Mat_SeqAIJKokkos     *aijkokA;
1460     Mat_SeqAIJ           *jaca;
1461     PetscInt              n = A->rmap->n, nnz;
1462     Mat                   Amat;
1463     PetscInt             *colmap;
1464 
1465     /* create and copy h_mat */
1466     h_mat.M = A->cmap->N; // use for debug build
1467     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1468     if (size == 1) {
1469       Amat            = A;
1470       jaca            = (Mat_SeqAIJ *)A->data;
1471       h_mat.rstart    = 0;
1472       h_mat.rend      = A->rmap->n;
1473       h_mat.cstart    = 0;
1474       h_mat.cend      = A->cmap->n;
1475       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1476       h_mat.offdiag.a                   = NULL;
1477       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1478     } else {
1479       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1480       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1481       PetscInt          ii;
1482       Mat_SeqAIJKokkos *aijkokB;
1483 
1484       Amat    = aij->A;
1485       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1486       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1487       jaca    = (Mat_SeqAIJ *)aij->A->data;
1488       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1489       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1490       aij->donotstash          = PETSC_TRUE;
1491       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1492       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1493       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1494       PetscCall(PetscLogObjectMemory((PetscObject)A, (A->cmap->N) * sizeof(PetscInt)));
1495       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1496       // allocate B copy data
1497       h_mat.rstart = A->rmap->rstart;
1498       h_mat.rend   = A->rmap->rend;
1499       h_mat.cstart = A->cmap->rstart;
1500       h_mat.cend   = A->cmap->rend;
1501       nnz          = jacb->i[n];
1502       if (jacb->compressedrow.use) {
1503         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1504         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1505         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1506         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1507       } else {
1508         h_mat.offdiag.i = aijkokB->i_device_data();
1509       }
1510       h_mat.offdiag.j = aijkokB->j_device_data();
1511       h_mat.offdiag.a = aijkokB->a_device_data();
1512       {
1513         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1514         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1515         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1516         h_mat.colmap = aijkokB->colmap_d.data();
1517         PetscCall(PetscFree(colmap));
1518       }
1519       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1520       h_mat.offdiag.n                 = n;
1521     }
1522     // allocate A copy data
1523     nnz                          = jaca->i[n];
1524     h_mat.diag.n                 = n;
1525     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1526     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1527     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not suppport compressed row (todo)");
1528     h_mat.diag.i = aijkokA->i_device_data();
1529     h_mat.diag.j = aijkokA->j_device_data();
1530     h_mat.diag.a = aijkokA->a_device_data();
1531     // copy pointers and metdata to device
1532     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1533     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1534     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1535   }
1536   *B           = d_mat;       // return it, set it in Mat, and set it up
1537   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1538   PetscFunctionReturn(0);
1539 }
1540 
1541 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask) {
1542   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1543 
1544   PetscFunctionBegin;
1545   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1546   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1547   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1548   else *mask = "PETSC_OFFLOAD_BOTH";
1549   PetscFunctionReturn(0);
1550 }
1551 
1552 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A) {
1553   PetscMPIInt size;
1554   Mat         Ad, Ao;
1555   const char *amask, *bmask;
1556 
1557   PetscFunctionBegin;
1558   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1559 
1560   if (size == 1) {
1561     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1562     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1563   } else {
1564     Ad = ((Mat_MPIAIJ *)A->data)->A;
1565     Ao = ((Mat_MPIAIJ *)A->data)->B;
1566     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1567     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1568     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1569   }
1570   PetscFunctionReturn(0);
1571 }
1572