xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision d2522c19e8fa9bca20aaca277941d9a63e71db6a)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) {
9   Mat_SeqAIJKokkos *aijkok;
10 
11   PetscFunctionBegin;
12   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
13   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
14   if (aijkok && aijkok->device_mat_d.data()) {
15     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
16   }
17 
18   PetscFunctionReturn(0);
19 }
20 
21 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) {
22   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
23 
24   PetscFunctionBegin;
25   PetscCall(PetscLayoutSetUp(mat->rmap));
26   PetscCall(PetscLayoutSetUp(mat->cmap));
27 #if defined(PETSC_USE_DEBUG)
28   if (d_nnz) {
29     PetscInt i;
30     for (i = 0; i < mat->rmap->n; i++) { PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); }
31   }
32   if (o_nnz) {
33     PetscInt i;
34     for (i = 0; i < mat->rmap->n; i++) { PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]); }
35   }
36 #endif
37 #if defined(PETSC_USE_CTABLE)
38   PetscCall(PetscTableDestroy(&mpiaij->colmap));
39 #else
40   PetscCall(PetscFree(mpiaij->colmap));
41 #endif
42   PetscCall(PetscFree(mpiaij->garray));
43   PetscCall(VecDestroy(&mpiaij->lvec));
44   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
45   /* Because the B will have been resized we simply destroy it and create a new one each time */
46   PetscCall(MatDestroy(&mpiaij->B));
47 
48   if (!mpiaij->A) {
49     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
50     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
51     PetscCall(PetscLogObjectParent((PetscObject)mat, (PetscObject)mpiaij->A));
52   }
53   if (!mpiaij->B) {
54     PetscMPIInt size;
55     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
56     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
57     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
58     PetscCall(PetscLogObjectParent((PetscObject)mat, (PetscObject)mpiaij->B));
59   }
60   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
61   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
62   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
63   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
64   mat->preallocated = PETSC_TRUE;
65   PetscFunctionReturn(0);
66 }
67 
68 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
69   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
70   PetscInt    nt;
71 
72   PetscFunctionBegin;
73   PetscCall(VecGetLocalSize(xx, &nt));
74   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
75   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
76   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
77   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
78   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
79   PetscFunctionReturn(0);
80 }
81 
82 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) {
83   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
84   PetscInt    nt;
85 
86   PetscFunctionBegin;
87   PetscCall(VecGetLocalSize(xx, &nt));
88   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
89   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
90   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
91   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
92   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
93   PetscFunctionReturn(0);
94 }
95 
96 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
97   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
98   PetscInt    nt;
99 
100   PetscFunctionBegin;
101   PetscCall(VecGetLocalSize(xx, &nt));
102   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
103   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
104   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
105   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
106   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
107   PetscFunctionReturn(0);
108 }
109 
110 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
111    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
112    C still uses local column ids. Their corresponding global column ids are returned in glob.
113 */
114 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) {
115   Mat             Ad, Ao;
116   const PetscInt *cmap;
117 
118   PetscFunctionBegin;
119   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
120   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
121   if (glob) {
122     PetscInt cst, i, dn, on, *gidx;
123     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
124     PetscCall(MatGetLocalSize(Ao, NULL, &on));
125     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
126     PetscCall(PetscMalloc1(dn + on, &gidx));
127     for (i = 0; i < dn; i++) gidx[i] = cst + i;
128     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
129     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
130   }
131   PetscFunctionReturn(0);
132 }
133 
134 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
135 struct MatMatStruct {
136   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
137   PetscSF             sf;      /* SF to send/recv matrix entries */
138   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
139   Mat                 C1, C2, B_local;
140   KokkosCsrMatrix     C1_global, C2_global, C_global;
141   KernelHandle        kh;
142   MatMatStruct() {
143     C1 = C2 = B_local = NULL;
144     sf                = NULL;
145   }
146 
147   ~MatMatStruct() {
148     MatDestroy(&C1);
149     MatDestroy(&C2);
150     MatDestroy(&B_local);
151     PetscSFDestroy(&sf);
152     kh.destroy_spadd_handle();
153   }
154 };
155 
156 struct MatMatStruct_AB : public MatMatStruct {
157   MatColIdxKokkosView rows;
158   MatRowMapKokkosView rowoffset;
159   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
160 
161   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
162   ~MatMatStruct_AB() {
163     MatDestroy(&B_other);
164     MatDestroy(&C_petsc);
165   }
166 };
167 
168 struct MatMatStruct_AtB : public MatMatStruct {
169   MatRowMapKokkosView srcrowoffset, dstrowoffset;
170 };
171 
172 struct MatProductData_MPIAIJKokkos {
173   MatMatStruct_AB  *mmAB;
174   MatMatStruct_AtB *mmAtB;
175   PetscBool         reusesym;
176 
177   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
178   ~MatProductData_MPIAIJKokkos() {
179     delete mmAB;
180     delete mmAtB;
181   }
182 };
183 
184 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) {
185   PetscFunctionBegin;
186   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
187   PetscFunctionReturn(0);
188 }
189 
190 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
191 
192    Input Parameters:
193 +  A       - the MATSEQAIJKOKKOS matrix
194 .  N       - new column size for the returned Kokkos matrix
195 -  l2g     - a map that maps old col ids to new col ids
196 
197    Output Parameters:
198 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
199  */
200 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat) {
201   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
202   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
203 
204   PetscFunctionBegin;
205   PetscCallCXX(Kokkos::parallel_for(
206     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
207   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
208   PetscFunctionReturn(0);
209 }
210 
211 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
212    It is similar to MatCreateMPIAIJWithSplitArrays.
213 
214   Input Parameters:
215 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
216 .  A     - the diag matrix using local col ids
217 -  B     - the offdiag matrix using global col ids
218 
219   Output Parameters:
220 .  mat   - the updated MATMPIAIJKOKKOS matrix
221 */
222 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B) {
223   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
224   PetscInt          m, n, M, N, Am, An, Bm, Bn;
225   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
226 
227   PetscFunctionBegin;
228   PetscCall(MatGetSize(mat, &M, &N));
229   PetscCall(MatGetLocalSize(mat, &m, &n));
230   PetscCall(MatGetLocalSize(A, &Am, &An));
231   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
232 
233   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
234   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
235   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
236   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
237   mpiaij->A = A;
238   mpiaij->B = B;
239 
240   mat->preallocated     = PETSC_TRUE;
241   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
242 
243   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
244   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
245   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
246     also gets mpiaij->B compacted, with its col ids and size reduced
247   */
248   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
249   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
250   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
251 
252   /* Update bkok with new local col ids (stored on host) and size */
253   bkok->j_dual.modify_host();
254   bkok->j_dual.sync_device();
255   bkok->SetColSize(mpiaij->B->cmap->n);
256   PetscFunctionReturn(0);
257 }
258 
259 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
260 
261    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
262    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
263    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
264    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
265 
266    Collective on comm of ownerSF
267 
268    Input Parameters:
269 +   B       - the SEQAIJKOKKOS matrix, using local col ids
270 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
271 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
272 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
273 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
274 
275    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
276 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
277 .   abuf      - buffer for sending matrix values
278 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
279                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
280 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
281 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
282 */
283 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C) {
284   Mat_SeqAIJKokkos *bkok, *ckok;
285 
286   PetscFunctionBegin;
287   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
288   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
289 
290   if (reuse == MAT_REUSE_MATRIX) {
291     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
292 
293     const auto &Ba = bkok->a_dual.view_device();
294     const auto &Bi = bkok->i_dual.view_device();
295     const auto &Ca = ckok->a_dual.view_device();
296 
297     /* Copy Ba to abuf */
298     Kokkos::parallel_for(
299       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
300         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
301         PetscInt r    = rows(i);
302         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
303         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
304       });
305 
306     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
307     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
308     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
309     ckok->a_dual.modify_device();
310   } else if (reuse == MAT_INITIAL_MATRIX) {
311     MPI_Comm    comm;
312     PetscMPIInt tag;
313     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
314 
315     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
316     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
317     Cm = nleaves; /* row size of C */
318     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
319 
320     /* Get row lens (nz) of B's rows for later fast query */
321     PetscInt       *Browlens;
322     const PetscInt *tmp = bkok->i_host_data();
323     PetscCall(PetscMalloc1(nroots, &Browlens));
324     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
325 
326     /* By ownerSF, each proc gets lens of rows of C */
327     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
328     Ci_h    = Ci.view_host().data();
329     Ci_h[0] = 0;
330     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
331     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
332     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
333     Cnnz = Ci_h[Cm];
334     Ci.modify_host();
335     Ci.sync_device();
336 
337     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
338     MatColIdxKokkosDualView Cj("j", Cnnz);
339     MatScalarKokkosDualView Ca("a", Cnnz);
340 
341     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
342     const PetscMPIInt *iranks, *ranks;
343     const PetscInt    *ioffset, *irootloc, *roffset;
344     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
345     MPI_Request       *reqs;
346 
347     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
348     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
349 
350     /* figure out offsets at the send buffer, to build the SF
351       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
352       rowptr[] - stores offsets for data of each row in abuf
353 
354       rdisp[]  - to receive sdisp[]
355     */
356     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
357     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
358     rowptr = rowptr_h.data();
359 
360     sdisp[0]  = 0;
361     rowptr[0] = 0;
362     for (i = 0; i < niranks; i++) { /* for each receiver */
363       PetscInt len, nz = 0;
364       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
365         len           = Browlens[irootloc[j]];
366         rowptr[j + 1] = rowptr[j] + len;
367         nz += len;
368       }
369       sdisp[i + 1] = sdisp[i] + nz;
370     }
371     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
372     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
373     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
374     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
375 
376     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
377     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
378     PetscSFNode *iremote;
379     PetscCall(PetscMalloc1(nleaves2, &iremote));
380     for (i = 0; i < nranks; i++) { /* for each sender */
381       k = 0;
382       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
383         iremote[j].rank  = ranks[i];
384         iremote[j].index = rdisp[i] + k;
385         k++;
386       }
387     }
388     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
389     PetscCall(PetscSFCreate(comm, &bcastSF));
390     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
391 
392     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
393       from local to global. Then use bcastSF to fill Ca, Cj.
394     */
395     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
396     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
397     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
398 
399     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
400 
401     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
402     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
403 
404     const auto &Ba = bkok->a_dual.view_device();
405     const auto &Bi = bkok->i_dual.view_device();
406     const auto &Bj = bkok->j_dual.view_device();
407 
408     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
409     Kokkos::parallel_for(
410       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
411         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
412         PetscInt r    = rows(i);
413         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
414         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
415           abuf(base + k) = Ba(Bi(r) + k);
416           jbuf(base + k) = l2g(Bj(Bi(r) + k));
417         });
418       });
419 
420     /* Send abuf & jbuf to fill Ca, Cj */
421     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
422     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
423     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
424     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
425     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
426     Cj.sync_host();
427     Ca.modify_device();
428 
429     /* Construct C with Ca, Ci, Cj */
430     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
431     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
432     PetscCall(PetscFree3(sdisp, rdisp, reqs));
433     PetscCall(PetscFree(Browlens));
434   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
435   PetscFunctionReturn(0);
436 }
437 
438 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
439 
440   It is the reverse of MatSeqAIJKokkosBcast in some sense.
441 
442   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
443   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
444   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
445 
446   Input Parameters:
447 +  A        - the SEQAIJKOKKOS matrix to be reduced
448 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
449 .  local    - true if A uses local col ids; false if A is already in global col ids.
450 .  N        - if local, N is A's global col size
451 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
452 -  ownerSF  - the SF specifies ownership (root) of rows in A
453 
454   Output Parameters:
455 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
456 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
457 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
458 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
459                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
460 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
461 
462    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
463  */
464 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C) {
465   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
466   const PetscInt   *Ai;
467   Mat_SeqAIJKokkos *akok;
468 
469   PetscFunctionBegin;
470   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
471   PetscCall(MatGetSize(A, &Am, &An));
472   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
473   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
474   Annz = Ai[Am];
475 
476   if (reuse == MAT_REUSE_MATRIX) {
477     /* Send Aa to abuf */
478     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
479     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
480 
481     /* Copy abuf to Ca */
482     const MatScalarKokkosView &Ca = C.values;
483     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
484     Kokkos::parallel_for(
485       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
486         PetscInt i   = t.league_rank();
487         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
488         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
489         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
490       });
491   } else if (reuse == MAT_INITIAL_MATRIX) {
492     MPI_Comm     comm;
493     MPI_Request *reqs;
494     PetscMPIInt  tag;
495     PetscInt     Cm;
496 
497     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
498     PetscCall(PetscCommGetNewTag(comm, &tag));
499 
500     PetscInt           niranks, nranks, nroots, nleaves;
501     const PetscMPIInt *iranks, *ranks;
502     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
503     PetscCall(PetscSFSetUp(ownerSF));
504     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
505     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
506     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
507     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
508     Cm    = nroots;
509     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
510 
511     /* Tell owners how long each row I will send */
512     PetscInt               *srowlens;                              /* send buf of row lens */
513     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
514     PetscInt               *rrowlens = rrowlens_h.data();
515 
516     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
517     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
518     rrowlens[0] = 0;
519     rrowlens++; /* shift the pointer to make the following expression more readable */
520     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
521     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
522     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
523 
524     /* Owner builds Ci on host by histogramming rrowlens[] */
525     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
526     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
527     MatRowMapType *Ci_ptr = Ci_h.data();
528 
529     for (i = 0; i < nrows; i++) {
530       r = rows[i]; /* local row id of i-th received row */
531 #if defined(PETSC_USE_DEBUG)
532       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
533 #endif
534       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
535     }
536     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
537     Cnnz = Ci_ptr[Cm];
538 
539     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
540     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
541     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
542     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
543 
544     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
545     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
546       r                    = rows[i];                   /* row id in C */
547       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
548       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
549     }
550     PetscCall(PetscFree(currowlens));
551 
552     rrowlens--;
553     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
554     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
555     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
556 
557     /* Build the reduceSF, which performs buffer to buffer send/recv */
558     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
559     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
560     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
561     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
562     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
563     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
564 
565     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
566     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
567     PetscSFNode *iremote;
568     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
569     for (i = 0; i < nranks; i++) {
570       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
571       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
572       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
573       for (PetscInt k = 0; k < nz; k++) {
574         iremote[leafbase + k].rank  = ranks[i];
575         iremote[leafbase + k].index = rootbase + k;
576       }
577     }
578     PetscCall(PetscSFCreate(comm, &reduceSF));
579     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
580     PetscCall(PetscFree2(sdisp, rdisp));
581 
582     /* Reduce Aa, Ajg to abuf and jbuf */
583 
584     /* If A uses local col ids, convert them to global ones before sending */
585     MatColIdxKokkosView Ajg;
586     if (local) {
587       Ajg                           = MatColIdxKokkosView("j", Annz);
588       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
589       Kokkos::parallel_for(
590         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
591     } else {
592       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
593     }
594 
595     MatColIdxKokkosView jbuf("jbuf", Cnnz);
596     abuf = MatScalarKokkosView("abuf", Cnnz);
597     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
598     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
599     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
600     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
601 
602     /* Copy data from abuf, jbuf to Ca, Cj */
603     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
604     MatColIdxKokkosView Cj("j", Cnnz);
605     MatScalarKokkosView Ca("a", Cnnz);
606 
607     Kokkos::parallel_for(
608       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
609         PetscInt i   = t.league_rank();
610         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
611         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
612         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
613           Ca(dst + k) = abuf(src + k);
614           Cj(dst + k) = jbuf(src + k);
615         });
616       });
617 
618     /* Build C with Ca, Ci, Cj */
619     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
620     PetscCall(PetscFree2(srowlens, reqs));
621   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
622   PetscFunctionReturn(0);
623 }
624 
625 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
626 
627   Input Parameters:
628 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
629 .  reuse    - indicate whether the matrix has called this function before
630 .  csrmat   - the KokkosCsrMatrix, of size m,N
631 -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
632               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
633               entry is 5, then Cdstart[i] = 3.
634 
635   Output Parameters:
636 +  C        - the updated MATMPIAIJKOKKOS matrix
637 -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
638 
639   Notes:
640    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
641  */
642 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart) {
643   const MatScalarKokkosView      &Ca = csrmat.values;
644   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
645   PetscInt                        m, n, N;
646 
647   PetscFunctionBegin;
648   PetscCall(MatGetLocalSize(C, &m, &n));
649   PetscCall(MatGetSize(C, NULL, &N));
650 
651   if (reuse == MAT_REUSE_MATRIX) {
652     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
653     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
654     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
655     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
656     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
657 
658     /* Fill 'a' of Cd and Co on device */
659     Kokkos::parallel_for(
660       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
661         PetscInt i       = t.league_rank();     /* row i */
662         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
663         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
664         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
665         PetscInt cdend   = cdstart + cdlen;
666         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
667         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
668           if (k < cdstart) { /* k in [0, cdstart) */
669             Coa(Coi(i) + k) = Ca(Ci(i) + k);
670           } else if (k < cdend) { /* k in [cdstart, cdend) */
671             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
672           } else { /* k in [cdend, clen) */
673             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
674           }
675         });
676       });
677 
678     akok->a_dual.modify_device();
679     bkok->a_dual.modify_device();
680   } else if (reuse == MAT_INITIAL_MATRIX) {
681     Mat                        Cd, Co;
682     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
683     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
684     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
685     PetscInt                   cstart, cend;
686 
687     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
688        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
689        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
690        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
691        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
692      */
693     Cdstart = MatRowMapKokkosView("Cdstart", m);
694     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
695 
696     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
697       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
698      */
699     Kokkos::parallel_for(
700       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
701         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
702                                                    PetscInt i = t.league_rank(); /* row i */
703                                                    PetscInt j, first, count, step;
704 
705                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
706                                                      Cdi(0) = 0;
707                                                      Coi(0) = 0;
708                                                    }
709 
710                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
711           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
712         */
713                                                    count = Ci(i + 1) - Ci(i);
714                                                    first = Ci(i);
715                                                    while (count > 0) {
716                                                      j    = first;
717                                                      step = count / 2;
718                                                      j += step;
719                                                      if (Cj(j) < cstart) {
720                                                        first = ++j;
721                                                        count -= step + 1;
722                                                      } else count = step;
723                                                    }
724                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
725 
726                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
727                                                    count = Ci(i + 1) - first;
728                                                    while (count > 0) {
729                                                      j    = first;
730                                                      step = count / 2;
731                                                      j += step;
732                                                      if (Cj(j) < cend) {
733                                                        first = ++j;
734                                                        count -= step + 1;
735                                                      } else count = step;
736                                                    }
737                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
738                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
739         });
740       });
741 
742     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
743     Kokkos::parallel_scan(
744       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
745         update += Cdi(i);
746         if (final) Cdi(i) = update;
747       });
748     Kokkos::parallel_scan(
749       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
750         update += Coi(i);
751         if (final) Coi(i) = update;
752       });
753 
754     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
755        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
756     */
757     Cdi_dual.modify_device();
758     Coi_dual.modify_device();
759     Cdi_dual.sync_host();
760     Coi_dual.sync_host();
761     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
762     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
763 
764     /* With nnz, allocate a, j for Cd and Co */
765     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
766     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
767 
768     /* Fill a, j of Cd and Co on device */
769     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
770     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
771 
772     Kokkos::parallel_for(
773       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
774         PetscInt i       = t.league_rank();     /* row i */
775         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
776         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
777         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
778         PetscInt cdend   = cdstart + cdlen;
779         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
780         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
781           if (k < cdstart) { /* k in [0, cdstart) */
782             Coa(Coi(i) + k) = Ca(Ci(i) + k);
783             Coj(Coi(i) + k) = Cj(Ci(i) + k);
784           } else if (k < cdend) { /* k in [cdstart, cdend) */
785             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
786             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
787           } else {                                                /* k in [cdend, clen) */
788             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
789             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
790           }
791         });
792       });
793 
794     Cdj_dual.modify_device();
795     Cda_dual.modify_device();
796     Coj_dual.modify_device();
797     Coa_dual.modify_device();
798     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
799     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
800     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
801     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
802     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
803     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
804   }
805   PetscFunctionReturn(0);
806 }
807 
808 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
809 
810   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
811 
812   Input Parameters:
813 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
814 .  reuse    - indicate whether the matrix has called this function before
815 .  csrmat   - the KokkosCsrMatrix, of size m,N
816 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
817               entry of the diag block of C in csrmat's j array.
818 
819   Output Parameters:
820 +  C        - the updated MATMPIAIJKOKKOS matrix
821 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
822 
823   Notes: the input matrix's col ids and col size will be changed.
824 */
825 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g) {
826   Mat_SeqAIJKokkos      *ckok;
827   ISLocalToGlobalMapping l2gmap;
828   const PetscInt        *garray;
829   PetscInt               sz;
830 
831   PetscFunctionBegin;
832   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
833   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
834   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
835   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
836   ckok->j_dual.sync_device();
837   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
838 
839   /* Build l2g -- the local to global mapping of C's cols */
840   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
841   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
842   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
843 
844   ConstMatColIdxKokkosViewHost tmp(garray, sz);
845   l2g = MatColIdxKokkosView("l2g", sz);
846   Kokkos::deep_copy(l2g, tmp);
847 
848   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
849   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
850   PetscFunctionReturn(0);
851 }
852 
853 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
854 
855   Input Parameters:
856 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
857 .  A        - an MPIAIJKOKKOS matrix
858 .  B        - an MPIAIJKOKKOS matrix
859 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
860 
861   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
862 */
863 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) {
864   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
865   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
866   IS                       glob = NULL;
867   const PetscInt          *garray;
868   PetscInt                 N = B->cmap->N, sz;
869   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
870   MatColIdxKokkosView      l2g2;
871   Mat                      C1, C2; /* intermediate matrices */
872 
873   PetscFunctionBegin;
874   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
875   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
876   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
877   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
878   PetscCall(MatProductSetFill(C1, product->fill));
879   C1->product->api_user = product->api_user;
880   PetscCall(MatProductSetFromOptions(C1));
881   PetscUseTypeMethod(C1, productsymbolic);
882 
883   PetscCall(ISGetIndices(glob, &garray));
884   PetscCall(ISGetSize(glob, &sz));
885   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
886   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
887   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
888 
889   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
890   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
891 
892   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
893   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
894   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
895   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
896   PetscCall(MatProductSetFill(C2, product->fill));
897   C2->product->api_user = product->api_user;
898   PetscCall(MatProductSetFromOptions(C2));
899   PetscUseTypeMethod(C2, productsymbolic);
900   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
901 
902   /* C = C1 + C2.  We actually use their global col ids versions in adding */
903   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
904   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
905   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
906   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
907 
908   mm->C1 = C1;
909   mm->C2 = C2;
910   PetscCall(ISRestoreIndices(glob, &garray));
911   PetscCall(ISDestroy(&glob));
912   PetscFunctionReturn(0);
913 }
914 
915 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
916 
917   Input Parameters:
918 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
919 .  A        - an MPIAIJKOKKOS matrix
920 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
921 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
922 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
923 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
924 -  mm       - a struct used to stash intermediate data in AtB
925 
926   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
927 */
928 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm) {
929   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
930   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
931   Mat         C1, C2;               /* intermediate matrices */
932 
933   PetscFunctionBegin;
934   /* C1 = Ad^t * B */
935   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
936   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
937   PetscCall(MatProductSetFill(C1, product->fill));
938   C1->product->api_user = product->api_user;
939   PetscCall(MatProductSetFromOptions(C1));
940   PetscUseTypeMethod(C1, productsymbolic);
941 
942   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
943   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
944 
945   /* C2 = Ao^t * B */
946   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
947   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
948   PetscCall(MatProductSetFill(C2, product->fill));
949   C2->product->api_user = product->api_user;
950   PetscCall(MatProductSetFromOptions(C2));
951   PetscUseTypeMethod(C2, productsymbolic);
952 
953   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
954 
955   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
956   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
957   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
958   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
959   mm->C1 = C1;
960   mm->C2 = C2;
961   PetscFunctionReturn(0);
962 }
963 
964 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) {
965   Mat_Product                 *product = C->product;
966   MatProductType               ptype;
967   MatProductData_MPIAIJKokkos *mmdata;
968   MatMatStruct                *mm = NULL;
969   MatMatStruct_AB             *ab;
970   MatMatStruct_AtB            *atb;
971   Mat                          A, B, Ad, Ao, Bd, Bo;
972   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
973 
974   PetscFunctionBegin;
975   MatCheckProduct(C, 1);
976   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
977   ptype  = product->type;
978   A      = product->A;
979   B      = product->B;
980   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
981   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
982   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
983   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
984 
985   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
986     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
987     ab               = mmdata->mmAB;
988     atb              = mmdata->mmAtB;
989     if (ab) {
990       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
991       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
992     }
993     if (atb) {
994       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
995       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
996     }
997     PetscFunctionReturn(0);
998   }
999 
1000   if (ptype == MATPRODUCT_AB) {
1001     ab = mmdata->mmAB;
1002     /* C1 = Ad * B_local */
1003     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1004     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /*glob*/, &ab->B_local));
1005     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1006     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1007     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1008     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /*N*/, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1009     /* C2 = Ao * B_other */
1010     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1011     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1012     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1013     /* C = C1_global + C2_global */
1014     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1015     mm = static_cast<MatMatStruct *>(ab);
1016   } else if (ptype == MATPRODUCT_AtB) {
1017     atb = mmdata->mmAtB;
1018     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1019     /* C1 = Ad^t * B_local */
1020     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /*glob*/, &atb->B_local));
1021     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1022     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1023     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1024 
1025     /* C2 = Ao^t * B_local */
1026     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1027     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1028     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1029     /* Form C2_global */
1030     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /*N*/, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1031     /* C = C1_global + C2_global */
1032     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1033     mm = static_cast<MatMatStruct *>(atb);
1034   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1035     ab = mmdata->mmAB;
1036     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /*glob*/, &ab->B_local));
1037 
1038     /* ab->C1 = Ad * B_local */
1039     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1040     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1041     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1042     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /*N*/, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1043     /* ab->C2 = Ao * B_other */
1044     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1045     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1046     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1047 
1048     /* atb->C1 = Bd^t * ab->C_petsc */
1049     atb = mmdata->mmAtB;
1050     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1051     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1052     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1053     /* atb->C2 = Bo^t * ab->C_petsc */
1054     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1055     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1056     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /*N*/, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1057     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1058     mm = static_cast<MatMatStruct *>(atb);
1059   }
1060   /* Split C_global to form C */
1061   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1062   PetscFunctionReturn(0);
1063 }
1064 
1065 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) {
1066   Mat                          A, B;
1067   Mat_Product                 *product = C->product;
1068   MatProductType               ptype;
1069   MatProductData_MPIAIJKokkos *mmdata;
1070   MatMatStruct                *mm   = NULL;
1071   IS                           glob = NULL;
1072   const PetscInt              *garray;
1073   PetscInt                     m, n, M, N, sz;
1074   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1075 
1076   PetscFunctionBegin;
1077   MatCheckProduct(C, 1);
1078   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1079   ptype = product->type;
1080   A     = product->A;
1081   B     = product->B;
1082 
1083   switch (ptype) {
1084   case MATPRODUCT_AB:
1085     m = A->rmap->n;
1086     n = B->cmap->n;
1087     M = A->rmap->N;
1088     N = B->cmap->N;
1089     break;
1090   case MATPRODUCT_AtB:
1091     m = A->cmap->n;
1092     n = B->cmap->n;
1093     M = A->cmap->N;
1094     N = B->cmap->N;
1095     break;
1096   case MATPRODUCT_PtAP:
1097     m = B->cmap->n;
1098     n = B->cmap->n;
1099     M = B->cmap->N;
1100     N = B->cmap->N;
1101     break; /* BtAB */
1102   default: SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1103   }
1104 
1105   PetscCall(MatSetSizes(C, m, n, M, N));
1106   PetscCall(PetscLayoutSetUp(C->rmap));
1107   PetscCall(PetscLayoutSetUp(C->cmap));
1108   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1109 
1110   mmdata           = new MatProductData_MPIAIJKokkos();
1111   mmdata->reusesym = product->api_user;
1112 
1113   if (ptype == MATPRODUCT_AB) {
1114     mmdata->mmAB = new MatMatStruct_AB();
1115     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1116     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1117   } else if (ptype == MATPRODUCT_AtB) {
1118     mmdata->mmAtB = new MatMatStruct_AtB();
1119     auto atb      = mmdata->mmAtB;
1120     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1121     PetscCall(ISGetIndices(glob, &garray));
1122     PetscCall(ISGetSize(glob, &sz));
1123     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
1124     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1125     PetscCall(ISRestoreIndices(glob, &garray));
1126     PetscCall(ISDestroy(&glob));
1127     mm = static_cast<MatMatStruct *>(atb);
1128   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1129     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1130     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1131     auto ab       = mmdata->mmAB;
1132     auto atb      = mmdata->mmAtB;
1133     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1134     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1135     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1136     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1137     mm = static_cast<MatMatStruct *>(atb);
1138   }
1139   /* Split the C_global into petsc A, B format */
1140   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1141   C->product->data       = mmdata;
1142   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1143   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1144   PetscFunctionReturn(0);
1145 }
1146 
1147 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) {
1148   Mat_Product *product = mat->product;
1149   PetscBool    match   = PETSC_FALSE;
1150   PetscBool    usecpu  = PETSC_FALSE;
1151 
1152   PetscFunctionBegin;
1153   MatCheckProduct(mat, 1);
1154   if (!product->A->boundtocpu && !product->B->boundtocpu) { PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); }
1155   if (match) { /* we can always fallback to the CPU if requested */
1156     switch (product->type) {
1157     case MATPRODUCT_AB:
1158       if (product->api_user) {
1159         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1160         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1161         PetscOptionsEnd();
1162       } else {
1163         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1164         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1165         PetscOptionsEnd();
1166       }
1167       break;
1168     case MATPRODUCT_AtB:
1169       if (product->api_user) {
1170         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1171         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1172         PetscOptionsEnd();
1173       } else {
1174         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1175         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1176         PetscOptionsEnd();
1177       }
1178       break;
1179     case MATPRODUCT_PtAP:
1180       if (product->api_user) {
1181         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1182         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1183         PetscOptionsEnd();
1184       } else {
1185         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1186         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1187         PetscOptionsEnd();
1188       }
1189       break;
1190     default: break;
1191     }
1192     match = (PetscBool)!usecpu;
1193   }
1194   if (match) {
1195     switch (product->type) {
1196     case MATPRODUCT_AB:
1197     case MATPRODUCT_AtB:
1198     case MATPRODUCT_PtAP: mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; break;
1199     default: break;
1200     }
1201   }
1202   /* fallback to MPIAIJ ops */
1203   if (!mat->ops->productsymbolic) { PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); }
1204   PetscFunctionReturn(0);
1205 }
1206 
1207 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) {
1208   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1209   Mat_MPIAIJKokkos *mpikok;
1210 
1211   PetscFunctionBegin;
1212   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
1213   mat->preallocated = PETSC_TRUE;
1214   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1215   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1216   PetscCall(MatZeroEntries(mat));
1217   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1218   delete mpikok;
1219   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1220   PetscFunctionReturn(0);
1221 }
1222 
1223 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) {
1224   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1225   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1226   Mat                         A = mpiaij->A, B = mpiaij->B;
1227   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1228   MatScalarKokkosView         Aa, Ba;
1229   MatScalarKokkosView         v1;
1230   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1231   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1232   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1233   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1234   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1235   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1236   PetscMemType                memtype;
1237 
1238   PetscFunctionBegin;
1239   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1240   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1241     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1242   } else {
1243     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1244   }
1245 
1246   if (imode == INSERT_VALUES) {
1247     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1248     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1249   } else {
1250     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1251     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1252   }
1253 
1254   /* Pack entries to be sent to remote */
1255   Kokkos::parallel_for(
1256     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1257 
1258   /* Send remote entries to their owner and overlap the communication with local computation */
1259   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1260   /* Add local entries to A and B in one kernel */
1261   Kokkos::parallel_for(
1262     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1263       PetscScalar sum = 0.0;
1264       if (i < Annz) {
1265         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1266         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1267       } else {
1268         i -= Annz;
1269         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1270         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1271       }
1272     });
1273   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1274 
1275   /* Add received remote entries to A and B in one kernel */
1276   Kokkos::parallel_for(
1277     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1278       if (i < Annz2) {
1279         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1280       } else {
1281         i -= Annz2;
1282         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1283       }
1284     });
1285 
1286   if (imode == INSERT_VALUES) {
1287     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1288     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1289   } else {
1290     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1291     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1292   }
1293   PetscFunctionReturn(0);
1294 }
1295 
1296 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) {
1297   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1298 
1299   PetscFunctionBegin;
1300   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1301   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1302   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1303   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1304   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1305   PetscCall(MatDestroy_MPIAIJ(A));
1306   PetscFunctionReturn(0);
1307 }
1308 
1309 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) {
1310   Mat         B;
1311   Mat_MPIAIJ *a;
1312 
1313   PetscFunctionBegin;
1314   if (reuse == MAT_INITIAL_MATRIX) {
1315     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1316   } else if (reuse == MAT_REUSE_MATRIX) {
1317     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1318   }
1319   B = *newmat;
1320 
1321   B->boundtocpu = PETSC_FALSE;
1322   PetscCall(PetscFree(B->defaultvectype));
1323   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1324   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1325 
1326   a = static_cast<Mat_MPIAIJ *>(A->data);
1327   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1328   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1329   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1330 
1331   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1332   B->ops->mult                  = MatMult_MPIAIJKokkos;
1333   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1334   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1335   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1336   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1337 
1338   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1339   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1340   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1341   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1342   PetscFunctionReturn(0);
1343 }
/*MC
   MATMPIAIJKOKKOS - MATAIJKOKKOS = "(mpi)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos

   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types

   Options Database Keys:
.  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()

  Level: beginner

.seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`
M*/
1356 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) {
1357   PetscFunctionBegin;
1358   PetscCall(PetscKokkosInitializeCheck());
1359   PetscCall(MatCreate_MPIAIJ(A));
1360   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1361   PetscFunctionReturn(0);
1362 }
1363 
/*@C
   MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.

   Collective

   Input Parameters:
+  comm - MPI communicator
.  m - number of local rows, as passed to MatSetSizes()
.  n - number of local columns, as passed to MatSetSizes()
.  M - number of global rows, as passed to MatSetSizes()
.  N - number of global columns, as passed to MatSetSizes()
.  d_nz - number of nonzeros per row in the diagonal portion of the local rows (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local rows (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           (possibly different for each row) or NULL

   Output Parameter:
.  A - the matrix

   It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
   MatXXXXSetPreallocation() paradigm instead of this routine directly.
   [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]

   Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored

   When comm has a single process the matrix is created as MATSEQAIJKOKKOS and only d_nz and d_nnz
   are used for preallocation; otherwise it is created as MATMPIAIJKOKKOS.

   The AIJ format (also called the Yale sparse matrix format or
   compressed row storage), is fully compatible with standard Fortran 77
   storage.  That is, the stored row and column indices can begin at
   either one (as in Fortran) or zero.  See the users' manual for details.

   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz=o_nz=PETSC_DEFAULT and d_nnz=o_nnz=NULL for PETSc to control dynamic memory
   allocation.  For large problems you MUST preallocate memory or you
   will get TERRIBLE performance, see the users' manual chapter on matrices.

   By default, this format uses inodes (identical nodes) when possible, to
   improve numerical efficiency of matrix-vector products and solves. We
   search for consecutive rows with the same nonzero structure, thereby
   reusing matrix information to achieve increased efficiency.

   Level: intermediate

.seealso: `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
@*/
1411 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) {
1412   PetscMPIInt size;
1413 
1414   PetscFunctionBegin;
1415   PetscCall(MatCreate(comm, A));
1416   PetscCall(MatSetSizes(*A, m, n, M, N));
1417   PetscCallMPI(MPI_Comm_size(comm, &size));
1418   if (size > 1) {
1419     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1420     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1421   } else {
1422     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1423     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1424   }
1425   PetscFunctionReturn(0);
1426 }
1427 
1428 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1429 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B) {
1430   PetscMPIInt                size, rank;
1431   MPI_Comm                   comm;
1432   PetscSplitCSRDataStructure d_mat = NULL;
1433 
1434   PetscFunctionBegin;
1435   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1436   PetscCallMPI(MPI_Comm_size(comm, &size));
1437   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1438   if (size == 1) {
1439     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1440     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1441   } else {
1442     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1443     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1444     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1445     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1446     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1447   }
1448   // act like MatSetValues because not called on host
1449   if (A->assembled) {
1450     if (A->was_assembled) { PetscCall(PetscInfo(A, "Assemble more than once already\n")); }
1451     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1452   } else {
1453     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1454   }
1455   if (!d_mat) {
1456     struct _n_SplitCSRMat h_mat; /* host container */
1457     Mat_SeqAIJKokkos     *aijkokA;
1458     Mat_SeqAIJ           *jaca;
1459     PetscInt              n = A->rmap->n, nnz;
1460     Mat                   Amat;
1461     PetscInt             *colmap;
1462 
1463     /* create and copy h_mat */
1464     h_mat.M = A->cmap->N; // use for debug build
1465     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1466     if (size == 1) {
1467       Amat            = A;
1468       jaca            = (Mat_SeqAIJ *)A->data;
1469       h_mat.rstart    = 0;
1470       h_mat.rend      = A->rmap->n;
1471       h_mat.cstart    = 0;
1472       h_mat.cend      = A->cmap->n;
1473       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1474       h_mat.offdiag.a                   = NULL;
1475       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1476     } else {
1477       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1478       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1479       PetscInt          ii;
1480       Mat_SeqAIJKokkos *aijkokB;
1481 
1482       Amat    = aij->A;
1483       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1484       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1485       jaca    = (Mat_SeqAIJ *)aij->A->data;
1486       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1487       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1488       aij->donotstash          = PETSC_TRUE;
1489       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1490       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1491       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1492       PetscCall(PetscLogObjectMemory((PetscObject)A, (A->cmap->N) * sizeof(PetscInt)));
1493       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1494       // allocate B copy data
1495       h_mat.rstart = A->rmap->rstart;
1496       h_mat.rend   = A->rmap->rend;
1497       h_mat.cstart = A->cmap->rstart;
1498       h_mat.cend   = A->cmap->rend;
1499       nnz          = jacb->i[n];
1500       if (jacb->compressedrow.use) {
1501         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1502         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1503         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1504         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1505       } else {
1506         h_mat.offdiag.i = aijkokB->i_device_data();
1507       }
1508       h_mat.offdiag.j = aijkokB->j_device_data();
1509       h_mat.offdiag.a = aijkokB->a_device_data();
1510       {
1511         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1512         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1513         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1514         h_mat.colmap = aijkokB->colmap_d.data();
1515         PetscCall(PetscFree(colmap));
1516       }
1517       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1518       h_mat.offdiag.n                 = n;
1519     }
1520     // allocate A copy data
1521     nnz                          = jaca->i[n];
1522     h_mat.diag.n                 = n;
1523     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1524     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1525     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not suppport compressed row (todo)");
1526     h_mat.diag.i = aijkokA->i_device_data();
1527     h_mat.diag.j = aijkokA->j_device_data();
1528     h_mat.diag.a = aijkokA->a_device_data();
1529     // copy pointers and metdata to device
1530     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1531     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1532     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1533   }
1534   *B           = d_mat;       // return it, set it in Mat, and set it up
1535   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1536   PetscFunctionReturn(0);
1537 }
1538 
1539 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask) {
1540   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1541 
1542   PetscFunctionBegin;
1543   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1544   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1545   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1546   else *mask = "PETSC_OFFLOAD_BOTH";
1547   PetscFunctionReturn(0);
1548 }
1549 
1550 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A) {
1551   PetscMPIInt size;
1552   Mat         Ad, Ao;
1553   const char *amask, *bmask;
1554 
1555   PetscFunctionBegin;
1556   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1557 
1558   if (size == 1) {
1559     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1560     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1561   } else {
1562     Ad = ((Mat_MPIAIJ *)A->data)->A;
1563     Ao = ((Mat_MPIAIJ *)A->data)->B;
1564     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1565     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1566     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1567   }
1568   PetscFunctionReturn(0);
1569 }
1570