xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision e0f5bfbec699682fa3e8b8532b1176849ea4e12a)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) {
9   Mat_SeqAIJKokkos *aijkok;
10 
11   PetscFunctionBegin;
12   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
13   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
14   if (aijkok && aijkok->device_mat_d.data()) {
15     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
16   }
17 
18   PetscFunctionReturn(0);
19 }
20 
21 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) {
22   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
23 
24   PetscFunctionBegin;
25   PetscCall(PetscLayoutSetUp(mat->rmap));
26   PetscCall(PetscLayoutSetUp(mat->cmap));
27 #if defined(PETSC_USE_DEBUG)
28   if (d_nnz) {
29     PetscInt i;
30     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
31   }
32   if (o_nnz) {
33     PetscInt i;
34     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
35   }
36 #endif
37 #if defined(PETSC_USE_CTABLE)
38   PetscCall(PetscTableDestroy(&mpiaij->colmap));
39 #else
40   PetscCall(PetscFree(mpiaij->colmap));
41 #endif
42   PetscCall(PetscFree(mpiaij->garray));
43   PetscCall(VecDestroy(&mpiaij->lvec));
44   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
45   /* Because the B will have been resized we simply destroy it and create a new one each time */
46   PetscCall(MatDestroy(&mpiaij->B));
47 
48   if (!mpiaij->A) {
49     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
50     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
51   }
52   if (!mpiaij->B) {
53     PetscMPIInt size;
54     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
55     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
56     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
57   }
58   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
59   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
60   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
61   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
62   mat->preallocated = PETSC_TRUE;
63   PetscFunctionReturn(0);
64 }
65 
66 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
67   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
68   PetscInt    nt;
69 
70   PetscFunctionBegin;
71   PetscCall(VecGetLocalSize(xx, &nt));
72   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
73   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
74   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
75   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
76   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
77   PetscFunctionReturn(0);
78 }
79 
80 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) {
81   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
82   PetscInt    nt;
83 
84   PetscFunctionBegin;
85   PetscCall(VecGetLocalSize(xx, &nt));
86   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
87   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
88   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
89   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
90   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
91   PetscFunctionReturn(0);
92 }
93 
94 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
95   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
96   PetscInt    nt;
97 
98   PetscFunctionBegin;
99   PetscCall(VecGetLocalSize(xx, &nt));
100   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
101   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
102   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
103   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
104   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
105   PetscFunctionReturn(0);
106 }
107 
108 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
109    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
110    C still uses local column ids. Their corresponding global column ids are returned in glob.
111 */
112 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) {
113   Mat             Ad, Ao;
114   const PetscInt *cmap;
115 
116   PetscFunctionBegin;
117   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
118   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
119   if (glob) {
120     PetscInt cst, i, dn, on, *gidx;
121     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
122     PetscCall(MatGetLocalSize(Ao, NULL, &on));
123     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
124     PetscCall(PetscMalloc1(dn + on, &gidx));
125     for (i = 0; i < dn; i++) gidx[i] = cst + i;
126     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
127     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
128   }
129   PetscFunctionReturn(0);
130 }
131 
132 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
133 struct MatMatStruct {
134   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
135   PetscSF             sf;      /* SF to send/recv matrix entries */
136   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
137   Mat                 C1, C2, B_local;
138   KokkosCsrMatrix     C1_global, C2_global, C_global;
139   KernelHandle        kh;
140   MatMatStruct() {
141     C1 = C2 = B_local = NULL;
142     sf                = NULL;
143   }
144 
145   ~MatMatStruct() {
146     MatDestroy(&C1);
147     MatDestroy(&C2);
148     MatDestroy(&B_local);
149     PetscSFDestroy(&sf);
150     kh.destroy_spadd_handle();
151   }
152 };
153 
154 struct MatMatStruct_AB : public MatMatStruct {
155   MatColIdxKokkosView rows;
156   MatRowMapKokkosView rowoffset;
157   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
158 
159   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
160   ~MatMatStruct_AB() {
161     MatDestroy(&B_other);
162     MatDestroy(&C_petsc);
163   }
164 };
165 
166 struct MatMatStruct_AtB : public MatMatStruct {
167   MatRowMapKokkosView srcrowoffset, dstrowoffset;
168 };
169 
170 struct MatProductData_MPIAIJKokkos {
171   MatMatStruct_AB  *mmAB;
172   MatMatStruct_AtB *mmAtB;
173   PetscBool         reusesym;
174 
175   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
176   ~MatProductData_MPIAIJKokkos() {
177     delete mmAB;
178     delete mmAtB;
179   }
180 };
181 
182 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) {
183   PetscFunctionBegin;
184   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
185   PetscFunctionReturn(0);
186 }
187 
188 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
189 
190    Input Parameters:
191 +  A       - the MATSEQAIJKOKKOS matrix
192 .  N       - new column size for the returned Kokkos matrix
193 -  l2g     - a map that maps old col ids to new col ids
194 
195    Output Parameters:
196 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
197  */
198 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat) {
199   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
200   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
201 
202   PetscFunctionBegin;
203   PetscCallCXX(Kokkos::parallel_for(
204     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
205   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
206   PetscFunctionReturn(0);
207 }
208 
209 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
210    It is similar to MatCreateMPIAIJWithSplitArrays.
211 
212   Input Parameters:
213 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
214 .  A     - the diag matrix using local col ids
215 -  B     - the offdiag matrix using global col ids
216 
217   Output Parameters:
218 .  mat   - the updated MATMPIAIJKOKKOS matrix
219 */
220 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B) {
221   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
222   PetscInt          m, n, M, N, Am, An, Bm, Bn;
223   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
224 
225   PetscFunctionBegin;
226   PetscCall(MatGetSize(mat, &M, &N));
227   PetscCall(MatGetLocalSize(mat, &m, &n));
228   PetscCall(MatGetLocalSize(A, &Am, &An));
229   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
230 
231   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
232   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
233   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
234   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
235   mpiaij->A = A;
236   mpiaij->B = B;
237 
238   mat->preallocated     = PETSC_TRUE;
239   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
240 
241   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
242   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
243   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
244     also gets mpiaij->B compacted, with its col ids and size reduced
245   */
246   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
247   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
248   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
249 
250   /* Update bkok with new local col ids (stored on host) and size */
251   bkok->j_dual.modify_host();
252   bkok->j_dual.sync_device();
253   bkok->SetColSize(mpiaij->B->cmap->n);
254   PetscFunctionReturn(0);
255 }
256 
257 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
258 
259    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
260    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
261    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
262    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
263 
264    Collective on comm of ownerSF
265 
266    Input Parameters:
267 +   B       - the SEQAIJKOKKOS matrix, using local col ids
268 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
269 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
270 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
271 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
272 
273    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
274 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
275 .   abuf      - buffer for sending matrix values
276 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
277                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
278 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
279 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
280 */
281 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C) {
282   Mat_SeqAIJKokkos *bkok, *ckok;
283 
284   PetscFunctionBegin;
285   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
286   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
287 
288   if (reuse == MAT_REUSE_MATRIX) {
289     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
290 
291     const auto &Ba = bkok->a_dual.view_device();
292     const auto &Bi = bkok->i_dual.view_device();
293     const auto &Ca = ckok->a_dual.view_device();
294 
295     /* Copy Ba to abuf */
296     Kokkos::parallel_for(
297       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
298         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
299         PetscInt r    = rows(i);
300         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
301         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
302       });
303 
304     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
305     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
306     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
307     ckok->a_dual.modify_device();
308   } else if (reuse == MAT_INITIAL_MATRIX) {
309     MPI_Comm    comm;
310     PetscMPIInt tag;
311     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
312 
313     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
314     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
315     Cm = nleaves; /* row size of C */
316     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
317 
318     /* Get row lens (nz) of B's rows for later fast query */
319     PetscInt       *Browlens;
320     const PetscInt *tmp = bkok->i_host_data();
321     PetscCall(PetscMalloc1(nroots, &Browlens));
322     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
323 
324     /* By ownerSF, each proc gets lens of rows of C */
325     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
326     Ci_h    = Ci.view_host().data();
327     Ci_h[0] = 0;
328     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
329     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
330     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
331     Cnnz = Ci_h[Cm];
332     Ci.modify_host();
333     Ci.sync_device();
334 
335     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
336     MatColIdxKokkosDualView Cj("j", Cnnz);
337     MatScalarKokkosDualView Ca("a", Cnnz);
338 
339     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
340     const PetscMPIInt *iranks, *ranks;
341     const PetscInt    *ioffset, *irootloc, *roffset;
342     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
343     MPI_Request       *reqs;
344 
345     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
346     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
347 
348     /* figure out offsets at the send buffer, to build the SF
349       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
350       rowptr[] - stores offsets for data of each row in abuf
351 
352       rdisp[]  - to receive sdisp[]
353     */
354     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
355     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
356     rowptr = rowptr_h.data();
357 
358     sdisp[0]  = 0;
359     rowptr[0] = 0;
360     for (i = 0; i < niranks; i++) { /* for each receiver */
361       PetscInt len, nz = 0;
362       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
363         len           = Browlens[irootloc[j]];
364         rowptr[j + 1] = rowptr[j] + len;
365         nz += len;
366       }
367       sdisp[i + 1] = sdisp[i] + nz;
368     }
369     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
370     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
371     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
372     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
373 
374     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
375     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
376     PetscSFNode *iremote;
377     PetscCall(PetscMalloc1(nleaves2, &iremote));
378     for (i = 0; i < nranks; i++) { /* for each sender */
379       k = 0;
380       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
381         iremote[j].rank  = ranks[i];
382         iremote[j].index = rdisp[i] + k;
383         k++;
384       }
385     }
386     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
387     PetscCall(PetscSFCreate(comm, &bcastSF));
388     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
389 
390     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
391       from local to global. Then use bcastSF to fill Ca, Cj.
392     */
393     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
394     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
395     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
396 
397     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
398 
399     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
400     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
401 
402     const auto &Ba = bkok->a_dual.view_device();
403     const auto &Bi = bkok->i_dual.view_device();
404     const auto &Bj = bkok->j_dual.view_device();
405 
406     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
407     Kokkos::parallel_for(
408       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
409         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
410         PetscInt r    = rows(i);
411         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
412         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
413           abuf(base + k) = Ba(Bi(r) + k);
414           jbuf(base + k) = l2g(Bj(Bi(r) + k));
415         });
416       });
417 
418     /* Send abuf & jbuf to fill Ca, Cj */
419     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
420     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
421     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
422     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
423     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
424     Cj.sync_host();
425     Ca.modify_device();
426 
427     /* Construct C with Ca, Ci, Cj */
428     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
429     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
430     PetscCall(PetscFree3(sdisp, rdisp, reqs));
431     PetscCall(PetscFree(Browlens));
432   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
433   PetscFunctionReturn(0);
434 }
435 
436 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
437 
438   It is the reverse of MatSeqAIJKokkosBcast in some sense.
439 
  Think of each row of A as a leaf; the given ownerSF then specifies the roots of the leaves. Roots may connect to multiple leaves.
441   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
442   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
443 
444   Input Parameters:
445 +  A        - the SEQAIJKOKKOS matrix to be reduced
446 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
447 .  local    - true if A uses local col ids; false if A is already in global col ids.
448 .  N        - if local, N is A's global col size
449 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
450 -  ownerSF  - the SF specifies ownership (root) of rows in A
451 
452   Output Parameters:
453 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
454 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
455 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
456 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
457                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
458 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
459 
   TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide an opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
461  */
462 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C) {
463   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
464   const PetscInt   *Ai;
465   Mat_SeqAIJKokkos *akok;
466 
467   PetscFunctionBegin;
468   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
469   PetscCall(MatGetSize(A, &Am, &An));
470   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
471   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
472   Annz = Ai[Am];
473 
474   if (reuse == MAT_REUSE_MATRIX) {
475     /* Send Aa to abuf */
476     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
477     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
478 
479     /* Copy abuf to Ca */
480     const MatScalarKokkosView &Ca = C.values;
481     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
482     Kokkos::parallel_for(
483       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
484         PetscInt i   = t.league_rank();
485         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
486         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
487         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
488       });
489   } else if (reuse == MAT_INITIAL_MATRIX) {
490     MPI_Comm     comm;
491     MPI_Request *reqs;
492     PetscMPIInt  tag;
493     PetscInt     Cm;
494 
495     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
496     PetscCall(PetscCommGetNewTag(comm, &tag));
497 
498     PetscInt           niranks, nranks, nroots, nleaves;
499     const PetscMPIInt *iranks, *ranks;
500     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
501     PetscCall(PetscSFSetUp(ownerSF));
502     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
503     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
504     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
505     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
506     Cm    = nroots;
507     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
508 
509     /* Tell owners how long each row I will send */
510     PetscInt               *srowlens;                              /* send buf of row lens */
511     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
512     PetscInt               *rrowlens = rrowlens_h.data();
513 
514     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
515     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
516     rrowlens[0] = 0;
517     rrowlens++; /* shift the pointer to make the following expression more readable */
518     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
519     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
520     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
521 
522     /* Owner builds Ci on host by histogramming rrowlens[] */
523     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
524     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
525     MatRowMapType *Ci_ptr = Ci_h.data();
526 
527     for (i = 0; i < nrows; i++) {
528       r = rows[i]; /* local row id of i-th received row */
529 #if defined(PETSC_USE_DEBUG)
530       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
531 #endif
532       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
533     }
534     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
535     Cnnz = Ci_ptr[Cm];
536 
537     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
538     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
539     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
540     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
541 
542     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
543     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
544       r                    = rows[i];                   /* row id in C */
545       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
546       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
547     }
548     PetscCall(PetscFree(currowlens));
549 
550     rrowlens--;
551     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
552     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
553     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
554 
555     /* Build the reduceSF, which performs buffer to buffer send/recv */
556     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
557     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
558     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
559     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
560     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
561     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
562 
563     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
564     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
565     PetscSFNode *iremote;
566     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
567     for (i = 0; i < nranks; i++) {
568       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
569       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
570       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
571       for (PetscInt k = 0; k < nz; k++) {
572         iremote[leafbase + k].rank  = ranks[i];
573         iremote[leafbase + k].index = rootbase + k;
574       }
575     }
576     PetscCall(PetscSFCreate(comm, &reduceSF));
577     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
578     PetscCall(PetscFree2(sdisp, rdisp));
579 
580     /* Reduce Aa, Ajg to abuf and jbuf */
581 
582     /* If A uses local col ids, convert them to global ones before sending */
583     MatColIdxKokkosView Ajg;
584     if (local) {
585       Ajg                           = MatColIdxKokkosView("j", Annz);
586       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
587       Kokkos::parallel_for(
588         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
589     } else {
590       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
591     }
592 
593     MatColIdxKokkosView jbuf("jbuf", Cnnz);
594     abuf = MatScalarKokkosView("abuf", Cnnz);
595     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
596     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
597     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
598     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
599 
600     /* Copy data from abuf, jbuf to Ca, Cj */
601     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
602     MatColIdxKokkosView Cj("j", Cnnz);
603     MatScalarKokkosView Ca("a", Cnnz);
604 
605     Kokkos::parallel_for(
606       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
607         PetscInt i   = t.league_rank();
608         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
609         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
610         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
611           Ca(dst + k) = abuf(src + k);
612           Cj(dst + k) = jbuf(src + k);
613         });
614       });
615 
616     /* Build C with Ca, Ci, Cj */
617     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
618     PetscCall(PetscFree2(srowlens, reqs));
619   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
620   PetscFunctionReturn(0);
621 }
622 
623 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
624 
625   Input Parameters:
626 +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
627 .  reuse    - indicate whether the matrix has called this function before
628 .  csrmat   - the KokkosCsrMatrix, of size m,N
629 -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
630               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
631               entry is 5, then Cdstart[i] = 3.
632 
633   Output Parameters:
634 +  C        - the updated `MATMPIAIJKOKKOS` matrix
635 -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
636 
637   Note:
638    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
639 
640 .seealso: `MATMPIAIJKOKKOS`
641  */
642 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart) {
643   const MatScalarKokkosView      &Ca = csrmat.values;
644   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
645   PetscInt                        m, n, N;
646 
647   PetscFunctionBegin;
648   PetscCall(MatGetLocalSize(C, &m, &n));
649   PetscCall(MatGetSize(C, NULL, &N));
650 
651   if (reuse == MAT_REUSE_MATRIX) {
652     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
653     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
654     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
655     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
656     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
657 
658     /* Fill 'a' of Cd and Co on device */
659     Kokkos::parallel_for(
660       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
661         PetscInt i       = t.league_rank();     /* row i */
662         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
663         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
664         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
665         PetscInt cdend   = cdstart + cdlen;
666         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
667         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
668           if (k < cdstart) { /* k in [0, cdstart) */
669             Coa(Coi(i) + k) = Ca(Ci(i) + k);
670           } else if (k < cdend) { /* k in [cdstart, cdend) */
671             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
672           } else { /* k in [cdend, clen) */
673             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
674           }
675         });
676       });
677 
678     akok->a_dual.modify_device();
679     bkok->a_dual.modify_device();
680   } else if (reuse == MAT_INITIAL_MATRIX) {
681     Mat                        Cd, Co;
682     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
683     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
684     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
685     PetscInt                   cstart, cend;
686 
687     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
688        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
689        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
690        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
691        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
692      */
693     Cdstart = MatRowMapKokkosView("Cdstart", m);
694     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
695 
696     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
697       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
698      */
699     Kokkos::parallel_for(
700       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
701         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
702                                                    PetscInt i = t.league_rank(); /* row i */
703                                                    PetscInt j, first, count, step;
704 
705                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
706                                                      Cdi(0) = 0;
707                                                      Coi(0) = 0;
708                                                    }
709 
710                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
711           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
712         */
713                                                    count = Ci(i + 1) - Ci(i);
714                                                    first = Ci(i);
715                                                    while (count > 0) {
716                                                      j    = first;
717                                                      step = count / 2;
718                                                      j += step;
719                                                      if (Cj(j) < cstart) {
720                                                        first = ++j;
721                                                        count -= step + 1;
722                                                      } else count = step;
723                                                    }
724                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
725 
726                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
727                                                    count = Ci(i + 1) - first;
728                                                    while (count > 0) {
729                                                      j    = first;
730                                                      step = count / 2;
731                                                      j += step;
732                                                      if (Cj(j) < cend) {
733                                                        first = ++j;
734                                                        count -= step + 1;
735                                                      } else count = step;
736                                                    }
737                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
738                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
739         });
740       });
741 
742     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
743     Kokkos::parallel_scan(
744       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
745         update += Cdi(i);
746         if (final) Cdi(i) = update;
747       });
748     Kokkos::parallel_scan(
749       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
750         update += Coi(i);
751         if (final) Coi(i) = update;
752       });
753 
754     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
755        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
756     */
757     Cdi_dual.modify_device();
758     Coi_dual.modify_device();
759     Cdi_dual.sync_host();
760     Coi_dual.sync_host();
761     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
762     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
763 
764     /* With nnz, allocate a, j for Cd and Co */
765     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
766     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
767 
768     /* Fill a, j of Cd and Co on device */
769     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
770     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
771 
772     Kokkos::parallel_for(
773       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
774         PetscInt i       = t.league_rank();     /* row i */
775         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
776         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
777         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
778         PetscInt cdend   = cdstart + cdlen;
779         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
780         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
781           if (k < cdstart) { /* k in [0, cdstart) */
782             Coa(Coi(i) + k) = Ca(Ci(i) + k);
783             Coj(Coi(i) + k) = Cj(Ci(i) + k);
784           } else if (k < cdend) { /* k in [cdstart, cdend) */
785             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
786             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
787           } else {                                                /* k in [cdend, clen) */
788             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
789             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
790           }
791         });
792       });
793 
794     Cdj_dual.modify_device();
795     Cda_dual.modify_device();
796     Coj_dual.modify_device();
797     Coa_dual.modify_device();
798     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
799     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
800     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
801     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
802     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
803     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
804   }
805   PetscFunctionReturn(0);
806 }
807 
808 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
809 
810   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
811 
812   Input Parameters:
813 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
814 .  reuse    - indicate whether the matrix has called this function before
815 .  csrmat   - the KokkosCsrMatrix, of size m,N
816 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
817               entry of the diag block of C in csrmat's j array.
818 
819   Output Parameters:
820 +  C        - the updated MATMPIAIJKOKKOS matrix
821 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
822 
823   Note:
824   the input matrix's col ids and col size will be changed.
825 */
826 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g) {
827   Mat_SeqAIJKokkos      *ckok;
828   ISLocalToGlobalMapping l2gmap;
829   const PetscInt        *garray;
830   PetscInt               sz;
831 
832   PetscFunctionBegin;
833   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
834   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
835   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
836   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
837   ckok->j_dual.sync_device();
838   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
839 
840   /* Build l2g -- the local to global mapping of C's cols */
841   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
842   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
843   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
844 
845   ConstMatColIdxKokkosViewHost tmp(garray, sz);
846   l2g = MatColIdxKokkosView("l2g", sz);
847   Kokkos::deep_copy(l2g, tmp);
848 
849   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
850   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
851   PetscFunctionReturn(0);
852 }
853 
854 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
855 
856   Input Parameters:
857 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
858 .  A        - an MPIAIJKOKKOS matrix
859 .  B        - an MPIAIJKOKKOS matrix
860 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
861 
862   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
863 */
864 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) {
865   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
866   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
867   IS                       glob = NULL;
868   const PetscInt          *garray;
869   PetscInt                 N = B->cmap->N, sz;
870   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
871   MatColIdxKokkosView      l2g2;
872   Mat                      C1, C2; /* intermediate matrices */
873 
874   PetscFunctionBegin;
875   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
876   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
877   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
878   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
879   PetscCall(MatProductSetFill(C1, product->fill));
880   C1->product->api_user = product->api_user;
881   PetscCall(MatProductSetFromOptions(C1));
882   PetscUseTypeMethod(C1, productsymbolic);
883 
884   PetscCall(ISGetIndices(glob, &garray));
885   PetscCall(ISGetSize(glob, &sz));
886   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
887   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
888   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
889 
890   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
891   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
892 
893   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
894   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
895   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
896   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
897   PetscCall(MatProductSetFill(C2, product->fill));
898   C2->product->api_user = product->api_user;
899   PetscCall(MatProductSetFromOptions(C2));
900   PetscUseTypeMethod(C2, productsymbolic);
901   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
902 
903   /* C = C1 + C2.  We actually use their global col ids versions in adding */
904   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
905   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
906   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
907   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
908 
909   mm->C1 = C1;
910   mm->C2 = C2;
911   PetscCall(ISRestoreIndices(glob, &garray));
912   PetscCall(ISDestroy(&glob));
913   PetscFunctionReturn(0);
914 }
915 
916 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
917 
918   Input Parameters:
919 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
920 .  A        - an MPIAIJKOKKOS matrix
921 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
922 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
923 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
924 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
925 -  mm       - a struct used to stash intermediate data in AtB
926 
927   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
928 */
929 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm) {
930   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
931   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
932   Mat         C1, C2;               /* intermediate matrices */
933 
934   PetscFunctionBegin;
935   /* C1 = Ad^t * B */
936   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
937   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
938   PetscCall(MatProductSetFill(C1, product->fill));
939   C1->product->api_user = product->api_user;
940   PetscCall(MatProductSetFromOptions(C1));
941   PetscUseTypeMethod(C1, productsymbolic);
942 
943   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
944   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
945 
946   /* C2 = Ao^t * B */
947   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
948   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
949   PetscCall(MatProductSetFill(C2, product->fill));
950   C2->product->api_user = product->api_user;
951   PetscCall(MatProductSetFromOptions(C2));
952   PetscUseTypeMethod(C2, productsymbolic);
953 
954   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
955 
956   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
957   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
958   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
959   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
960   mm->C1 = C1;
961   mm->C2 = C2;
962   PetscFunctionReturn(0);
963 }
964 
965 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) {
966   Mat_Product                 *product = C->product;
967   MatProductType               ptype;
968   MatProductData_MPIAIJKokkos *mmdata;
969   MatMatStruct                *mm = NULL;
970   MatMatStruct_AB             *ab;
971   MatMatStruct_AtB            *atb;
972   Mat                          A, B, Ad, Ao, Bd, Bo;
973   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
974 
975   PetscFunctionBegin;
976   MatCheckProduct(C, 1);
977   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
978   ptype  = product->type;
979   A      = product->A;
980   B      = product->B;
981   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
982   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
983   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
984   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
985 
986   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
987     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
988     ab               = mmdata->mmAB;
989     atb              = mmdata->mmAtB;
990     if (ab) {
991       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
992       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
993     }
994     if (atb) {
995       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
996       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
997     }
998     PetscFunctionReturn(0);
999   }
1000 
1001   if (ptype == MATPRODUCT_AB) {
1002     ab = mmdata->mmAB;
1003     /* C1 = Ad * B_local */
1004     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1005     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1006     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1007     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1008     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1009     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1010     /* C2 = Ao * B_other */
1011     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1012     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1013     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1014     /* C = C1_global + C2_global */
1015     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1016     mm = static_cast<MatMatStruct *>(ab);
1017   } else if (ptype == MATPRODUCT_AtB) {
1018     atb = mmdata->mmAtB;
1019     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1020     /* C1 = Ad^t * B_local */
1021     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1022     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1023     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1024     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1025 
1026     /* C2 = Ao^t * B_local */
1027     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1028     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1029     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1030     /* Form C2_global */
1031     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1032     /* C = C1_global + C2_global */
1033     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1034     mm = static_cast<MatMatStruct *>(atb);
1035   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1036     ab = mmdata->mmAB;
1037     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1038 
1039     /* ab->C1 = Ad * B_local */
1040     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1041     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1042     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1043     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1044     /* ab->C2 = Ao * B_other */
1045     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1046     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1047     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1048 
1049     /* atb->C1 = Bd^t * ab->C_petsc */
1050     atb = mmdata->mmAtB;
1051     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1052     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1053     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1054     /* atb->C2 = Bo^t * ab->C_petsc */
1055     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1056     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1057     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1058     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1059     mm = static_cast<MatMatStruct *>(atb);
1060   }
1061   /* Split C_global to form C */
1062   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1063   PetscFunctionReturn(0);
1064 }
1065 
1066 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) {
1067   Mat                          A, B;
1068   Mat_Product                 *product = C->product;
1069   MatProductType               ptype;
1070   MatProductData_MPIAIJKokkos *mmdata;
1071   MatMatStruct                *mm   = NULL;
1072   IS                           glob = NULL;
1073   const PetscInt              *garray;
1074   PetscInt                     m, n, M, N, sz;
1075   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1076 
1077   PetscFunctionBegin;
1078   MatCheckProduct(C, 1);
1079   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1080   ptype = product->type;
1081   A     = product->A;
1082   B     = product->B;
1083 
1084   switch (ptype) {
1085   case MATPRODUCT_AB:
1086     m = A->rmap->n;
1087     n = B->cmap->n;
1088     M = A->rmap->N;
1089     N = B->cmap->N;
1090     break;
1091   case MATPRODUCT_AtB:
1092     m = A->cmap->n;
1093     n = B->cmap->n;
1094     M = A->cmap->N;
1095     N = B->cmap->N;
1096     break;
1097   case MATPRODUCT_PtAP:
1098     m = B->cmap->n;
1099     n = B->cmap->n;
1100     M = B->cmap->N;
1101     N = B->cmap->N;
1102     break; /* BtAB */
1103   default: SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1104   }
1105 
1106   PetscCall(MatSetSizes(C, m, n, M, N));
1107   PetscCall(PetscLayoutSetUp(C->rmap));
1108   PetscCall(PetscLayoutSetUp(C->cmap));
1109   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1110 
1111   mmdata           = new MatProductData_MPIAIJKokkos();
1112   mmdata->reusesym = product->api_user;
1113 
1114   if (ptype == MATPRODUCT_AB) {
1115     mmdata->mmAB = new MatMatStruct_AB();
1116     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1117     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1118   } else if (ptype == MATPRODUCT_AtB) {
1119     mmdata->mmAtB = new MatMatStruct_AtB();
1120     auto atb      = mmdata->mmAtB;
1121     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1122     PetscCall(ISGetIndices(glob, &garray));
1123     PetscCall(ISGetSize(glob, &sz));
1124     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
1125     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1126     PetscCall(ISRestoreIndices(glob, &garray));
1127     PetscCall(ISDestroy(&glob));
1128     mm = static_cast<MatMatStruct *>(atb);
1129   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1130     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1131     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1132     auto ab       = mmdata->mmAB;
1133     auto atb      = mmdata->mmAtB;
1134     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1135     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1136     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1137     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1138     mm = static_cast<MatMatStruct *>(atb);
1139   }
1140   /* Split the C_global into petsc A, B format */
1141   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1142   C->product->data       = mmdata;
1143   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1144   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1145   PetscFunctionReturn(0);
1146 }
1147 
1148 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) {
1149   Mat_Product *product = mat->product;
1150   PetscBool    match   = PETSC_FALSE;
1151   PetscBool    usecpu  = PETSC_FALSE;
1152 
1153   PetscFunctionBegin;
1154   MatCheckProduct(mat, 1);
1155   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1156   if (match) { /* we can always fallback to the CPU if requested */
1157     switch (product->type) {
1158     case MATPRODUCT_AB:
1159       if (product->api_user) {
1160         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1161         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1162         PetscOptionsEnd();
1163       } else {
1164         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1165         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1166         PetscOptionsEnd();
1167       }
1168       break;
1169     case MATPRODUCT_AtB:
1170       if (product->api_user) {
1171         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1172         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1173         PetscOptionsEnd();
1174       } else {
1175         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1176         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1177         PetscOptionsEnd();
1178       }
1179       break;
1180     case MATPRODUCT_PtAP:
1181       if (product->api_user) {
1182         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1183         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1184         PetscOptionsEnd();
1185       } else {
1186         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1187         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1188         PetscOptionsEnd();
1189       }
1190       break;
1191     default: break;
1192     }
1193     match = (PetscBool)!usecpu;
1194   }
1195   if (match) {
1196     switch (product->type) {
1197     case MATPRODUCT_AB:
1198     case MATPRODUCT_AtB:
1199     case MATPRODUCT_PtAP: mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; break;
1200     default: break;
1201     }
1202   }
1203   /* fallback to MPIAIJ ops */
1204   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1205   PetscFunctionReturn(0);
1206 }
1207 
1208 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) {
1209   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1210   Mat_MPIAIJKokkos *mpikok;
1211 
1212   PetscFunctionBegin;
1213   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
1214   mat->preallocated = PETSC_TRUE;
1215   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1216   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1217   PetscCall(MatZeroEntries(mat));
1218   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1219   delete mpikok;
1220   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1221   PetscFunctionReturn(0);
1222 }
1223 
/*
  MatSetValuesCOO_MPIAIJKokkos - Set or add values, given in the COO order used at
  MatSetPreallocationCOO() time, into the local diagonal (A) and off-diagonal (B) blocks on the device.

  Input Parameters:
+ mat   - the matrix, whose COO maps were built by MatSetPreallocationCOO()
. v     - mpiaij->coo_n values, in host or device memory (NULL is treated as host memory)
- imode - INSERT_VALUES overwrites the current values of A and B; any other mode adds to them

  Entries destined for other ranks are packed into sendbuf_d and reduced onto their owners through
  mpiaij->coo_sf. The local entries are summed while that communication is in flight.
*/
static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) {
  Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
  Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
  Mat                         A = mpiaij->A, B = mpiaij->B;
  PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
  MatScalarKokkosView         Aa, Ba;
  MatScalarKokkosView         v1; /* device view of the user's v[] */
  MatScalarKokkosView        &vsend  = mpikok->sendbuf_d; /* packed entries to send to other ranks */
  const MatScalarKokkosView  &v2     = mpikok->recvbuf_d; /* entries received from other ranks */
  /* Note: in each declaration below only the first name is a reference; the others are View copies (cheap handles) */
  const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
  const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
  const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
  const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
  PetscMemType                memtype;

  PetscFunctionBegin;
  PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
  if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
    v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
  } else {
    v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
  }

  /* INSERT_VALUES only needs write access: every entry of A and B is overwritten in the first kernel below */
  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
    PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
    PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
  }

  /* Pack entries to be sent to remote */
  Kokkos::parallel_for(
    vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

  /* Send remote entries to their owner and overlap the communication with local computation */
  PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
  /* Add local entries to A and B in one kernel */
  /* Indices [0, Annz) handle nonzeros of A, [Annz, Annz + Bnnz) nonzeros of B. For nonzero i, the COO entries
     v1(perm1(k)) with k in [jmap1(i), jmap1(i+1)) are summed into it. */
  Kokkos::parallel_for(
    Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
      PetscScalar sum = 0.0;
      if (i < Annz) {
        for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
        Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
      } else {
        i -= Annz;
        for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
        Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
      }
    });
  PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

  /* Add received remote entries to A and B in one kernel */
  /* Indices [0, Annz2) target nonzero Aimap2(i) of A, [Annz2, Annz2 + Bnnz2) target nonzero Bimap2(i) of B.
     These are always added, since the local kernel above already applied the insert/add mode. */
  Kokkos::parallel_for(
    Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
      if (i < Annz2) {
        for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
      } else {
        i -= Annz2;
        for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
      }
    });

  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
    PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
  }
  PetscFunctionReturn(0);
}
1296 
1297 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) {
1298   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1299 
1300   PetscFunctionBegin;
1301   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1302   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1303   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1304   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1305   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1306   PetscCall(MatDestroy_MPIAIJ(A));
1307   PetscFunctionReturn(0);
1308 }
1309 
1310 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) {
1311   Mat         B;
1312   Mat_MPIAIJ *a;
1313 
1314   PetscFunctionBegin;
1315   if (reuse == MAT_INITIAL_MATRIX) {
1316     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1317   } else if (reuse == MAT_REUSE_MATRIX) {
1318     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1319   }
1320   B = *newmat;
1321 
1322   B->boundtocpu = PETSC_FALSE;
1323   PetscCall(PetscFree(B->defaultvectype));
1324   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1325   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1326 
1327   a = static_cast<Mat_MPIAIJ *>(A->data);
1328   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1329   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1330   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1331 
1332   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1333   B->ops->mult                  = MatMult_MPIAIJKokkos;
1334   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1335   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1336   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1337   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1338 
1339   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1340   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1341   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1342   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1343   PetscFunctionReturn(0);
1344 }
1345 /*MC
1346    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1347 
   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
1349 
1350    Options Database Keys:
1351 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1352 
1353   Level: beginner
1354 
1355 .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1356 M*/
1357 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) {
1358   PetscFunctionBegin;
1359   PetscCall(PetscKokkosInitializeCheck());
1360   PetscCall(MatCreate_MPIAIJ(A));
1361   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1362   PetscFunctionReturn(0);
1363 }
1364 
1365 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
1371    performance during matrix assembly can be increased by more than a factor of 50.
1372 
1373    Collective
1374 
1375    Input Parameters:
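+  comm - MPI communicator
.  m - number of local rows (or `PETSC_DECIDE`)
.  n - number of local columns (or `PETSC_DECIDE`)
.  M - number of global rows (or `PETSC_DETERMINE`)
.  N - number of global columns (or `PETSC_DETERMINE`)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion of the local submatrix
         (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion of the local submatrix
         (possibly different for each row) or NULL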
1382 
1383    Output Parameter:
1384 .  A - the matrix
1385 
1386    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1387    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1388    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1389 
1390    Notes:
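   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored. On a communicator with a single process a
   `MATSEQAIJKOKKOS` matrix is created and o_nz and o_nnz are ignored.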
1392 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
1394    storage.  That is, the stored row and column indices can begin at
1395    either one (as in Fortran) or zero.  See the users' manual for details.
1396 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
1399    allocation.  For large problems you MUST preallocate memory or you
1400    will get TERRIBLE performance, see the users' manual chapter on matrices.
1401 
1402    By default, this format uses inodes (identical nodes) when possible, to
1403    improve numerical efficiency of matrix-vector products and solves. We
1404    search for consecutive rows with the same nonzero structure, thereby
1405    reusing matrix information to achieve increased efficiency.
1406 
1407    Level: intermediate
1408 
.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1410 @*/
1411 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) {
1412   PetscMPIInt size;
1413 
1414   PetscFunctionBegin;
1415   PetscCall(MatCreate(comm, A));
1416   PetscCall(MatSetSizes(*A, m, n, M, N));
1417   PetscCallMPI(MPI_Comm_size(comm, &size));
1418   if (size > 1) {
1419     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1420     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1421   } else {
1422     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1423     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1424   }
1425   PetscFunctionReturn(0);
1426 }
1427 
1428 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1429 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B) {
1430   PetscMPIInt                size, rank;
1431   MPI_Comm                   comm;
1432   PetscSplitCSRDataStructure d_mat = NULL;
1433 
1434   PetscFunctionBegin;
1435   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1436   PetscCallMPI(MPI_Comm_size(comm, &size));
1437   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1438   if (size == 1) {
1439     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1440     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1441   } else {
1442     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1443     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1444     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1445     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1446     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1447   }
1448   // act like MatSetValues because not called on host
1449   if (A->assembled) {
1450     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1451     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1452   } else {
1453     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1454   }
1455   if (!d_mat) {
1456     struct _n_SplitCSRMat h_mat; /* host container */
1457     Mat_SeqAIJKokkos     *aijkokA;
1458     Mat_SeqAIJ           *jaca;
1459     PetscInt              n = A->rmap->n, nnz;
1460     Mat                   Amat;
1461     PetscInt             *colmap;
1462 
1463     /* create and copy h_mat */
1464     h_mat.M = A->cmap->N; // use for debug build
1465     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1466     if (size == 1) {
1467       Amat            = A;
1468       jaca            = (Mat_SeqAIJ *)A->data;
1469       h_mat.rstart    = 0;
1470       h_mat.rend      = A->rmap->n;
1471       h_mat.cstart    = 0;
1472       h_mat.cend      = A->cmap->n;
1473       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1474       h_mat.offdiag.a                   = NULL;
1475       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1476     } else {
1477       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1478       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1479       PetscInt          ii;
1480       Mat_SeqAIJKokkos *aijkokB;
1481 
1482       Amat    = aij->A;
1483       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1484       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1485       jaca    = (Mat_SeqAIJ *)aij->A->data;
1486       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1487       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1488       aij->donotstash          = PETSC_TRUE;
1489       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1490       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1491       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1492       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1493       // allocate B copy data
1494       h_mat.rstart = A->rmap->rstart;
1495       h_mat.rend   = A->rmap->rend;
1496       h_mat.cstart = A->cmap->rstart;
1497       h_mat.cend   = A->cmap->rend;
1498       nnz          = jacb->i[n];
1499       if (jacb->compressedrow.use) {
1500         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1501         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1502         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1503         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1504       } else {
1505         h_mat.offdiag.i = aijkokB->i_device_data();
1506       }
1507       h_mat.offdiag.j = aijkokB->j_device_data();
1508       h_mat.offdiag.a = aijkokB->a_device_data();
1509       {
1510         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1511         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1512         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1513         h_mat.colmap = aijkokB->colmap_d.data();
1514         PetscCall(PetscFree(colmap));
1515       }
1516       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1517       h_mat.offdiag.n                 = n;
1518     }
1519     // allocate A copy data
1520     nnz                          = jaca->i[n];
1521     h_mat.diag.n                 = n;
1522     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1523     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1524     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not suppport compressed row (todo)");
1525     h_mat.diag.i = aijkokA->i_device_data();
1526     h_mat.diag.j = aijkokA->j_device_data();
1527     h_mat.diag.a = aijkokA->a_device_data();
1528     // copy pointers and metdata to device
1529     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1530     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1531     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1532   }
1533   *B           = d_mat;       // return it, set it in Mat, and set it up
1534   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1535   PetscFunctionReturn(0);
1536 }
1537 
1538 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask) {
1539   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1540 
1541   PetscFunctionBegin;
1542   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1543   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1544   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1545   else *mask = "PETSC_OFFLOAD_BOTH";
1546   PetscFunctionReturn(0);
1547 }
1548 
1549 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A) {
1550   PetscMPIInt size;
1551   Mat         Ad, Ao;
1552   const char *amask, *bmask;
1553 
1554   PetscFunctionBegin;
1555   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1556 
1557   if (size == 1) {
1558     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1559     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1560   } else {
1561     Ad = ((Mat_MPIAIJ *)A->data)->A;
1562     Ao = ((Mat_MPIAIJ *)A->data)->B;
1563     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1564     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1565     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1566   }
1567   PetscFunctionReturn(0);
1568 }
1569