xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision d7547e516efde0cd36ffdeebcfafd4768debadcc)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
9 {
10   Mat_SeqAIJKokkos *aijkok;
11 
12   PetscFunctionBegin;
13   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
14   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
15   if (aijkok && aijkok->device_mat_d.data()) {
16     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
17   }
18 
19   PetscFunctionReturn(0);
20 }
21 
22 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
23 {
24   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
25 
26   PetscFunctionBegin;
27   PetscCall(PetscLayoutSetUp(mat->rmap));
28   PetscCall(PetscLayoutSetUp(mat->cmap));
29 #if defined(PETSC_USE_DEBUG)
30   if (d_nnz) {
31     PetscInt i;
32     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
33   }
34   if (o_nnz) {
35     PetscInt i;
36     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
37   }
38 #endif
39 #if defined(PETSC_USE_CTABLE)
40   PetscCall(PetscTableDestroy(&mpiaij->colmap));
41 #else
42   PetscCall(PetscFree(mpiaij->colmap));
43 #endif
44   PetscCall(PetscFree(mpiaij->garray));
45   PetscCall(VecDestroy(&mpiaij->lvec));
46   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
47   /* Because the B will have been resized we simply destroy it and create a new one each time */
48   PetscCall(MatDestroy(&mpiaij->B));
49 
50   if (!mpiaij->A) {
51     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
52     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
53   }
54   if (!mpiaij->B) {
55     PetscMPIInt size;
56     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
57     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
58     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
59   }
60   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
61   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
62   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
63   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
64   mat->preallocated = PETSC_TRUE;
65   PetscFunctionReturn(0);
66 }
67 
68 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
69 {
70   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
71   PetscInt    nt;
72 
73   PetscFunctionBegin;
74   PetscCall(VecGetLocalSize(xx, &nt));
75   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
76   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
77   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
78   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
79   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
80   PetscFunctionReturn(0);
81 }
82 
83 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
84 {
85   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
86   PetscInt    nt;
87 
88   PetscFunctionBegin;
89   PetscCall(VecGetLocalSize(xx, &nt));
90   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
91   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
92   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
93   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
94   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
95   PetscFunctionReturn(0);
96 }
97 
98 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
99 {
100   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
101   PetscInt    nt;
102 
103   PetscFunctionBegin;
104   PetscCall(VecGetLocalSize(xx, &nt));
105   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
106   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
107   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
108   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
109   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
110   PetscFunctionReturn(0);
111 }
112 
113 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
114    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
115    C still uses local column ids. Their corresponding global column ids are returned in glob.
116 */
117 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
118 {
119   Mat             Ad, Ao;
120   const PetscInt *cmap;
121 
122   PetscFunctionBegin;
123   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
124   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
125   if (glob) {
126     PetscInt cst, i, dn, on, *gidx;
127     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
128     PetscCall(MatGetLocalSize(Ao, NULL, &on));
129     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
130     PetscCall(PetscMalloc1(dn + on, &gidx));
131     for (i = 0; i < dn; i++) gidx[i] = cst + i;
132     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
133     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
134   }
135   PetscFunctionReturn(0);
136 }
137 
138 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
139 struct MatMatStruct {
140   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
141   PetscSF             sf;      /* SF to send/recv matrix entries */
142   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
143   Mat                 C1, C2, B_local;
144   KokkosCsrMatrix     C1_global, C2_global, C_global;
145   KernelHandle        kh;
146   MatMatStruct()
147   {
148     C1 = C2 = B_local = NULL;
149     sf                = NULL;
150   }
151 
152   ~MatMatStruct()
153   {
154     MatDestroy(&C1);
155     MatDestroy(&C2);
156     MatDestroy(&B_local);
157     PetscSFDestroy(&sf);
158     kh.destroy_spadd_handle();
159   }
160 };
161 
162 struct MatMatStruct_AB : public MatMatStruct {
163   MatColIdxKokkosView rows;
164   MatRowMapKokkosView rowoffset;
165   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
166 
167   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
168   ~MatMatStruct_AB()
169   {
170     MatDestroy(&B_other);
171     MatDestroy(&C_petsc);
172   }
173 };
174 
175 struct MatMatStruct_AtB : public MatMatStruct {
176   MatRowMapKokkosView srcrowoffset, dstrowoffset;
177 };
178 
179 struct MatProductData_MPIAIJKokkos {
180   MatMatStruct_AB  *mmAB;
181   MatMatStruct_AtB *mmAtB;
182   PetscBool         reusesym;
183 
184   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
185   ~MatProductData_MPIAIJKokkos()
186   {
187     delete mmAB;
188     delete mmAtB;
189   }
190 };
191 
192 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
193 {
194   PetscFunctionBegin;
195   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
196   PetscFunctionReturn(0);
197 }
198 
199 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
200 
201    Input Parameters:
202 +  A       - the MATSEQAIJKOKKOS matrix
203 .  N       - new column size for the returned Kokkos matrix
204 -  l2g     - a map that maps old col ids to new col ids
205 
206    Output Parameters:
207 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
208  */
209 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
210 {
211   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
212   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
213 
214   PetscFunctionBegin;
215   PetscCallCXX(Kokkos::parallel_for(
216     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
217   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
218   PetscFunctionReturn(0);
219 }
220 
221 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
222    It is similar to MatCreateMPIAIJWithSplitArrays.
223 
224   Input Parameters:
225 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
226 .  A     - the diag matrix using local col ids
227 -  B     - the offdiag matrix using global col ids
228 
229   Output Parameters:
230 .  mat   - the updated MATMPIAIJKOKKOS matrix
231 */
232 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
233 {
234   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
235   PetscInt          m, n, M, N, Am, An, Bm, Bn;
236   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
237 
238   PetscFunctionBegin;
239   PetscCall(MatGetSize(mat, &M, &N));
240   PetscCall(MatGetLocalSize(mat, &m, &n));
241   PetscCall(MatGetLocalSize(A, &Am, &An));
242   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
243 
244   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
245   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
246   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
247   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
248   mpiaij->A = A;
249   mpiaij->B = B;
250 
251   mat->preallocated     = PETSC_TRUE;
252   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
253 
254   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
255   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
256   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
257     also gets mpiaij->B compacted, with its col ids and size reduced
258   */
259   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
260   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
261   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
262 
263   /* Update bkok with new local col ids (stored on host) and size */
264   bkok->j_dual.modify_host();
265   bkok->j_dual.sync_device();
266   bkok->SetColSize(mpiaij->B->cmap->n);
267   PetscFunctionReturn(0);
268 }
269 
270 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
271 
272    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
273    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
274    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
275    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
276 
277    Collective on comm of ownerSF
278 
279    Input Parameters:
280 +   B       - the SEQAIJKOKKOS matrix, using local col ids
281 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
282 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
283 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
284 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
285 
286    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
287 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
288 .   abuf      - buffer for sending matrix values
289 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
290                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
291 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
292 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
293 */
294 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
295 {
296   Mat_SeqAIJKokkos *bkok, *ckok;
297 
298   PetscFunctionBegin;
299   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
300   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
301 
302   if (reuse == MAT_REUSE_MATRIX) {
303     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
304 
305     const auto &Ba = bkok->a_dual.view_device();
306     const auto &Bi = bkok->i_dual.view_device();
307     const auto &Ca = ckok->a_dual.view_device();
308 
309     /* Copy Ba to abuf */
310     Kokkos::parallel_for(
311       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
312         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
313         PetscInt r    = rows(i);
314         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
315         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
316       });
317 
318     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
319     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
320     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
321     ckok->a_dual.modify_device();
322   } else if (reuse == MAT_INITIAL_MATRIX) {
323     MPI_Comm    comm;
324     PetscMPIInt tag;
325     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
326 
327     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
328     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
329     Cm = nleaves; /* row size of C */
330     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
331 
332     /* Get row lens (nz) of B's rows for later fast query */
333     PetscInt       *Browlens;
334     const PetscInt *tmp = bkok->i_host_data();
335     PetscCall(PetscMalloc1(nroots, &Browlens));
336     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
337 
338     /* By ownerSF, each proc gets lens of rows of C */
339     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
340     Ci_h    = Ci.view_host().data();
341     Ci_h[0] = 0;
342     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
343     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
344     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
345     Cnnz = Ci_h[Cm];
346     Ci.modify_host();
347     Ci.sync_device();
348 
349     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
350     MatColIdxKokkosDualView Cj("j", Cnnz);
351     MatScalarKokkosDualView Ca("a", Cnnz);
352 
353     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
354     const PetscMPIInt *iranks, *ranks;
355     const PetscInt    *ioffset, *irootloc, *roffset;
356     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
357     MPI_Request       *reqs;
358 
359     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
360     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
361 
362     /* figure out offsets at the send buffer, to build the SF
363       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
364       rowptr[] - stores offsets for data of each row in abuf
365 
366       rdisp[]  - to receive sdisp[]
367     */
368     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
369     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
370     rowptr = rowptr_h.data();
371 
372     sdisp[0]  = 0;
373     rowptr[0] = 0;
374     for (i = 0; i < niranks; i++) { /* for each receiver */
375       PetscInt len, nz = 0;
376       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
377         len           = Browlens[irootloc[j]];
378         rowptr[j + 1] = rowptr[j] + len;
379         nz += len;
380       }
381       sdisp[i + 1] = sdisp[i] + nz;
382     }
383     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
384     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
385     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
386     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
387 
388     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
389     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
390     PetscSFNode *iremote;
391     PetscCall(PetscMalloc1(nleaves2, &iremote));
392     for (i = 0; i < nranks; i++) { /* for each sender */
393       k = 0;
394       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
395         iremote[j].rank  = ranks[i];
396         iremote[j].index = rdisp[i] + k;
397         k++;
398       }
399     }
400     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
401     PetscCall(PetscSFCreate(comm, &bcastSF));
402     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
403 
404     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
405       from local to global. Then use bcastSF to fill Ca, Cj.
406     */
407     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
408     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
409     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
410 
411     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
412 
413     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
414     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
415 
416     const auto &Ba = bkok->a_dual.view_device();
417     const auto &Bi = bkok->i_dual.view_device();
418     const auto &Bj = bkok->j_dual.view_device();
419 
420     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
421     Kokkos::parallel_for(
422       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
423         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
424         PetscInt r    = rows(i);
425         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
426         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
427           abuf(base + k) = Ba(Bi(r) + k);
428           jbuf(base + k) = l2g(Bj(Bi(r) + k));
429         });
430       });
431 
432     /* Send abuf & jbuf to fill Ca, Cj */
433     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
434     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
435     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
436     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
437     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
438     Cj.sync_host();
439     Ca.modify_device();
440 
441     /* Construct C with Ca, Ci, Cj */
442     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
443     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
444     PetscCall(PetscFree3(sdisp, rdisp, reqs));
445     PetscCall(PetscFree(Browlens));
446   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
447   PetscFunctionReturn(0);
448 }
449 
450 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
451 
452   It is the reverse of MatSeqAIJKokkosBcast in some sense.
453 
454   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
455   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
456   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
457 
458   Input Parameters:
459 +  A        - the SEQAIJKOKKOS matrix to be reduced
460 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
461 .  local    - true if A uses local col ids; false if A is already in global col ids.
462 .  N        - if local, N is A's global col size
463 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
464 -  ownerSF  - the SF specifies ownership (root) of rows in A
465 
466   Output Parameters:
467 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
468 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
469 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
470 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
471                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
472 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
473 
474    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
475  */
476 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
477 {
478   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
479   const PetscInt   *Ai;
480   Mat_SeqAIJKokkos *akok;
481 
482   PetscFunctionBegin;
483   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
484   PetscCall(MatGetSize(A, &Am, &An));
485   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
486   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
487   Annz = Ai[Am];
488 
489   if (reuse == MAT_REUSE_MATRIX) {
490     /* Send Aa to abuf */
491     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
492     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
493 
494     /* Copy abuf to Ca */
495     const MatScalarKokkosView &Ca = C.values;
496     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
497     Kokkos::parallel_for(
498       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
499         PetscInt i   = t.league_rank();
500         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
501         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
502         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
503       });
504   } else if (reuse == MAT_INITIAL_MATRIX) {
505     MPI_Comm     comm;
506     MPI_Request *reqs;
507     PetscMPIInt  tag;
508     PetscInt     Cm;
509 
510     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
511     PetscCall(PetscCommGetNewTag(comm, &tag));
512 
513     PetscInt           niranks, nranks, nroots, nleaves;
514     const PetscMPIInt *iranks, *ranks;
515     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
516     PetscCall(PetscSFSetUp(ownerSF));
517     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
518     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
519     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
520     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
521     Cm    = nroots;
522     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
523 
524     /* Tell owners how long each row I will send */
525     PetscInt               *srowlens;                              /* send buf of row lens */
526     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
527     PetscInt               *rrowlens = rrowlens_h.data();
528 
529     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
530     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
531     rrowlens[0] = 0;
532     rrowlens++; /* shift the pointer to make the following expression more readable */
533     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
534     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
535     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
536 
537     /* Owner builds Ci on host by histogramming rrowlens[] */
538     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
539     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
540     MatRowMapType *Ci_ptr = Ci_h.data();
541 
542     for (i = 0; i < nrows; i++) {
543       r = rows[i]; /* local row id of i-th received row */
544 #if defined(PETSC_USE_DEBUG)
545       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
546 #endif
547       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
548     }
549     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
550     Cnnz = Ci_ptr[Cm];
551 
552     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
553     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
554     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
555     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
556 
557     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
558     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
559       r                    = rows[i];                   /* row id in C */
560       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
561       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
562     }
563     PetscCall(PetscFree(currowlens));
564 
565     rrowlens--;
566     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
567     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
568     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
569 
570     /* Build the reduceSF, which performs buffer to buffer send/recv */
571     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
572     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
573     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
574     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
575     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
576     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
577 
578     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
579     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
580     PetscSFNode *iremote;
581     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
582     for (i = 0; i < nranks; i++) {
583       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
584       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
585       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
586       for (PetscInt k = 0; k < nz; k++) {
587         iremote[leafbase + k].rank  = ranks[i];
588         iremote[leafbase + k].index = rootbase + k;
589       }
590     }
591     PetscCall(PetscSFCreate(comm, &reduceSF));
592     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
593     PetscCall(PetscFree2(sdisp, rdisp));
594 
595     /* Reduce Aa, Ajg to abuf and jbuf */
596 
597     /* If A uses local col ids, convert them to global ones before sending */
598     MatColIdxKokkosView Ajg;
599     if (local) {
600       Ajg                           = MatColIdxKokkosView("j", Annz);
601       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
602       Kokkos::parallel_for(
603         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
604     } else {
605       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
606     }
607 
608     MatColIdxKokkosView jbuf("jbuf", Cnnz);
609     abuf = MatScalarKokkosView("abuf", Cnnz);
610     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
611     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
612     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
613     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
614 
615     /* Copy data from abuf, jbuf to Ca, Cj */
616     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
617     MatColIdxKokkosView Cj("j", Cnnz);
618     MatScalarKokkosView Ca("a", Cnnz);
619 
620     Kokkos::parallel_for(
621       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
622         PetscInt i   = t.league_rank();
623         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
624         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
625         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
626           Ca(dst + k) = abuf(src + k);
627           Cj(dst + k) = jbuf(src + k);
628         });
629       });
630 
631     /* Build C with Ca, Ci, Cj */
632     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
633     PetscCall(PetscFree2(srowlens, reqs));
634   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
635   PetscFunctionReturn(0);
636 }
637 
638 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
639 
640   Input Parameters:
641 +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
642 .  reuse    - indicate whether the matrix has called this function before
643 .  csrmat   - the KokkosCsrMatrix, of size m,N
644 -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
645               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
646               entry is 5, then Cdstart[i] = 3.
647 
648   Output Parameters:
649 +  C        - the updated `MATMPIAIJKOKKOS` matrix
650 -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
651 
652   Note:
653    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
654 
655 .seealso: `MATMPIAIJKOKKOS`
656  */
657 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
658 {
659   const MatScalarKokkosView      &Ca = csrmat.values;
660   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
661   PetscInt                        m, n, N;
662 
663   PetscFunctionBegin;
664   PetscCall(MatGetLocalSize(C, &m, &n));
665   PetscCall(MatGetSize(C, NULL, &N));
666 
667   if (reuse == MAT_REUSE_MATRIX) {
668     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
669     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
670     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
671     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
672     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
673 
674     /* Fill 'a' of Cd and Co on device */
675     Kokkos::parallel_for(
676       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
677         PetscInt i       = t.league_rank();     /* row i */
678         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
679         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
680         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
681         PetscInt cdend   = cdstart + cdlen;
682         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
683         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
684           if (k < cdstart) { /* k in [0, cdstart) */
685             Coa(Coi(i) + k) = Ca(Ci(i) + k);
686           } else if (k < cdend) { /* k in [cdstart, cdend) */
687             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
688           } else { /* k in [cdend, clen) */
689             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
690           }
691         });
692       });
693 
694     akok->a_dual.modify_device();
695     bkok->a_dual.modify_device();
696   } else if (reuse == MAT_INITIAL_MATRIX) {
697     Mat                        Cd, Co;
698     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
699     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
700     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
701     PetscInt                   cstart, cend;
702 
703     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
704        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
705        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
706        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
707        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
708      */
709     Cdstart = MatRowMapKokkosView("Cdstart", m);
710     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
711 
712     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
713       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
714      */
715     Kokkos::parallel_for(
716       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
717         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
718                                                    PetscInt i = t.league_rank(); /* row i */
719                                                    PetscInt j, first, count, step;
720 
721                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
722                                                      Cdi(0) = 0;
723                                                      Coi(0) = 0;
724                                                    }
725 
726                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
727           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
728         */
729                                                    count = Ci(i + 1) - Ci(i);
730                                                    first = Ci(i);
731                                                    while (count > 0) {
732                                                      j    = first;
733                                                      step = count / 2;
734                                                      j += step;
735                                                      if (Cj(j) < cstart) {
736                                                        first = ++j;
737                                                        count -= step + 1;
738                                                      } else count = step;
739                                                    }
740                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
741 
742                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
743                                                    count = Ci(i + 1) - first;
744                                                    while (count > 0) {
745                                                      j    = first;
746                                                      step = count / 2;
747                                                      j += step;
748                                                      if (Cj(j) < cend) {
749                                                        first = ++j;
750                                                        count -= step + 1;
751                                                      } else count = step;
752                                                    }
753                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
754                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
755         });
756       });
757 
758     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
759     Kokkos::parallel_scan(
760       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
761         update += Cdi(i);
762         if (final) Cdi(i) = update;
763       });
764     Kokkos::parallel_scan(
765       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
766         update += Coi(i);
767         if (final) Coi(i) = update;
768       });
769 
770     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
771        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
772     */
773     Cdi_dual.modify_device();
774     Coi_dual.modify_device();
775     Cdi_dual.sync_host();
776     Coi_dual.sync_host();
777     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
778     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
779 
780     /* With nnz, allocate a, j for Cd and Co */
781     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
782     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
783 
784     /* Fill a, j of Cd and Co on device */
785     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
786     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
787 
788     Kokkos::parallel_for(
789       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
790         PetscInt i       = t.league_rank();     /* row i */
791         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
792         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
793         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
794         PetscInt cdend   = cdstart + cdlen;
795         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
796         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
797           if (k < cdstart) { /* k in [0, cdstart) */
798             Coa(Coi(i) + k) = Ca(Ci(i) + k);
799             Coj(Coi(i) + k) = Cj(Ci(i) + k);
800           } else if (k < cdend) { /* k in [cdstart, cdend) */
801             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
802             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
803           } else {                                                /* k in [cdend, clen) */
804             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
805             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
806           }
807         });
808       });
809 
810     Cdj_dual.modify_device();
811     Cda_dual.modify_device();
812     Coj_dual.modify_device();
813     Coa_dual.modify_device();
814     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
815     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
816     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
817     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
818     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
819     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
820   }
821   PetscFunctionReturn(0);
822 }
823 
824 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.
825 
826   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
827 
828   Input Parameter:
829 .  C        - the SEQAIJKOKKOS matrix to compact (e.g., B_other in the AB product)
830 
831   Output Parameters:
832 +  C        - the same matrix, with its col ids and col size changed in place
833 -  l2g      - a Kokkos view of length C->cmap->n (the col size after compaction), holding the local to global
834               mapping of C's cols, copied from the mapping returned by MatSeqAIJCompactOutExtraColumns_SeqAIJ()
835 
836   Notes:
837   The input matrix's col ids and col size will be changed.
838 
839   The col ids are rewritten on host by MatSeqAIJCompactOutExtraColumns_SeqAIJ(); this routine then syncs them to device
840   and updates the col size of the csrmat kept in C->spptr.
841 */
842 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
843 {
844   Mat_SeqAIJKokkos      *ckok;
845   ISLocalToGlobalMapping l2gmap;
846   const PetscInt        *garray;
847   PetscInt               sz;
848 
849   PetscFunctionBegin;
850   /* Compact C's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
851   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
852   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
853   ckok->j_dual.modify_host(); /* C's j is modified on host; we need to sync it on device */
854   ckok->j_dual.sync_device();
855   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
856 
857   /* Build l2g -- the local to global mapping of C's cols */
858   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
859   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
860   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
861 
862   ConstMatColIdxKokkosViewHost tmp(garray, sz); /* wrap garray as a host view, no copy */
863   l2g = MatColIdxKokkosView("l2g", sz);
864   Kokkos::deep_copy(l2g, tmp); /* l2g owns its own copy, so garray and l2gmap can be released below */
865 
866   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
867   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
868   PetscFunctionReturn(0);
869 }
870 
871 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
872 
873   Input Parameters:
874 +  product  - Mat_Product which carried out the computation. Its fill and api_user are passed on to the two sequential sub-products.
875 .  A        - an MPIAIJKOKKOS matrix
876 .  B        - an MPIAIJKOKKOS matrix
877 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
878 
879   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
880   On return mm also holds B_local, B_other, the sub-products C1, C2, their global-col-id versions C1_global, C2_global, and the spadd handle kh.
881 */
881 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
882 {
883   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
884   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
885   IS                       glob = NULL;
886   const PetscInt          *garray;
887   PetscInt                 N = B->cmap->N, sz;
888   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
889   MatColIdxKokkosView      l2g2;
890   Mat                      C1, C2; /* intermediate matrices */
891 
892   PetscFunctionBegin;
893   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
894   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
895   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
896   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
897   PetscCall(MatProductSetFill(C1, product->fill));
898   C1->product->api_user = product->api_user; /* the sub-product inherits the parent's api_user flag */
899   PetscCall(MatProductSetFromOptions(C1));
900   PetscUseTypeMethod(C1, productsymbolic);
901 
902   PetscCall(ISGetIndices(glob, &garray));
903   PetscCall(ISGetSize(glob, &sz));
904   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
905   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
906   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
907 
908   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
909   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
910 
911   /* Compact B_other to use local ids as we guess KK spgemm is more memory scalable with that; We could skip the compaction to simplify code */
912   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
913   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
914   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
915   PetscCall(MatProductSetFill(C2, product->fill));
916   C2->product->api_user = product->api_user; /* the sub-product inherits the parent's api_user flag */
917   PetscCall(MatProductSetFromOptions(C2));
918   PetscUseTypeMethod(C2, productsymbolic);
919   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
920 
921   /* C = C1 + C2.  We actually use their global col ids versions in adding */
922   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
923   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
924   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
925   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
926 
927   mm->C1 = C1; /* keep the sub-products in mm so the numeric phase can rerun them */
928   mm->C2 = C2;
929   PetscCall(ISRestoreIndices(glob, &garray)); /* done only now because l2g1 may alias garray (see above) */
930   PetscCall(ISDestroy(&glob));
931   PetscFunctionReturn(0);
932 }
933 
934 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
935 
936   Input Parameters:
937 +  product  - Mat_Product which carried out the computation. Its fill and api_user are passed on to the two sequential sub-products.
938 .  A        - an MPIAIJKOKKOS matrix
939 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
940 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
941 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
942 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
943 -  mm       - a struct used to stash intermediate data in AtB (on return it holds C1, C2, C1_global, C2_global, the spadd handle kh and C_global)
944 
945   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
946 */
947 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
948 {
949   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
950   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
951   Mat         C1, C2;               /* intermediate matrices */
952 
953   PetscFunctionBegin;
954   /* C1 = Ad^t * B */
955   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
956   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
957   PetscCall(MatProductSetFill(C1, product->fill));
958   C1->product->api_user = product->api_user; /* the sub-product inherits the parent's api_user flag */
959   PetscCall(MatProductSetFromOptions(C1));
960   PetscUseTypeMethod(C1, productsymbolic);
961 
962   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
963   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
964 
965   /* C2 = Ao^t * B */
966   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
967   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
968   PetscCall(MatProductSetFill(C2, product->fill));
969   C2->product->api_user = product->api_user; /* the sub-product inherits the parent's api_user flag */
970   PetscCall(MatProductSetFromOptions(C2));
971   PetscUseTypeMethod(C2, productsymbolic);
972 
973   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global)); /* form C2_global from C2 using A's Mvctx */
974 
975   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
976   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
977   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
978   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
979   mm->C1 = C1; /* keep the sub-products in mm so the numeric phase can rerun them */
980   mm->C2 = C2;
981   PetscFunctionReturn(0);
982 }
983 
984 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
985 {
986   Mat_Product                 *product = C->product;
987   MatProductType               ptype;
988   MatProductData_MPIAIJKokkos *mmdata;
989   MatMatStruct                *mm = NULL;
990   MatMatStruct_AB             *ab;
991   MatMatStruct_AtB            *atb;
992   Mat                          A, B, Ad, Ao, Bd, Bo;
993   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
994 
995   PetscFunctionBegin;
996   MatCheckProduct(C, 1);
997   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
998   ptype  = product->type;
999   A      = product->A;
1000   B      = product->B;
1001   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1002   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1003   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1004   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1005 
1006   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1007     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1008     ab               = mmdata->mmAB;
1009     atb              = mmdata->mmAtB;
1010     if (ab) {
1011       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1012       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1013     }
1014     if (atb) {
1015       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1016       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1017     }
1018     PetscFunctionReturn(0);
1019   }
1020 
1021   if (ptype == MATPRODUCT_AB) {
1022     ab = mmdata->mmAB;
1023     /* C1 = Ad * B_local */
1024     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1025     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1026     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1027     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1028     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1029     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1030     /* C2 = Ao * B_other */
1031     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1032     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1033     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1034     /* C = C1_global + C2_global */
1035     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1036     mm = static_cast<MatMatStruct *>(ab);
1037   } else if (ptype == MATPRODUCT_AtB) {
1038     atb = mmdata->mmAtB;
1039     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1040     /* C1 = Ad^t * B_local */
1041     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1042     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1043     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1044     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1045 
1046     /* C2 = Ao^t * B_local */
1047     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1048     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1049     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1050     /* Form C2_global */
1051     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1052     /* C = C1_global + C2_global */
1053     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1054     mm = static_cast<MatMatStruct *>(atb);
1055   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1056     ab = mmdata->mmAB;
1057     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1058 
1059     /* ab->C1 = Ad * B_local */
1060     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1061     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1062     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1063     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1064     /* ab->C2 = Ao * B_other */
1065     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1066     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1067     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1068 
1069     /* atb->C1 = Bd^t * ab->C_petsc */
1070     atb = mmdata->mmAtB;
1071     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1072     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1073     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1074     /* atb->C2 = Bo^t * ab->C_petsc */
1075     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1076     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1077     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1078     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1079     mm = static_cast<MatMatStruct *>(atb);
1080   }
1081   /* Split C_global to form C */
1082   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1083   PetscFunctionReturn(0);
1084 }
1085 
/* MatProductSymbolic_MPIAIJKokkos - Symbolic phase of the MPIAIJKOKKOS mat products AB, A^tB and PtAP (computed as B^tAB)

  Input Parameter:
.  C - the product matrix; C->product gives the operands A, B and the product type, and C->product->data must be empty

  Output Parameter:
.  C - sized, typed like A, and filled by splitting the computed mm->C_global into its diag and offdiag parts

  Notes:
  A MatProductData_MPIAIJKokkos is allocated here and attached to C->product->data, together with its destroy routine
  and MatProductNumeric_MPIAIJKokkos() as the numeric op. Its reusesym flag is initialized from product->api_user.

  Any product type other than AB, AtB or PtAP raises PETSC_ERR_PLIB.
*/
1086 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1087 {
1088   Mat                          A, B;
1089   Mat_Product                 *product = C->product;
1090   MatProductType               ptype;
1091   MatProductData_MPIAIJKokkos *mmdata;
1092   MatMatStruct                *mm   = NULL;
1093   IS                           glob = NULL;
1094   const PetscInt              *garray;
1095   PetscInt                     m, n, M, N, sz;
1096   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1097 
1098   PetscFunctionBegin;
1099   MatCheckProduct(C, 1);
1100   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1101   ptype = product->type;
1102   A     = product->A;
1103   B     = product->B;
1104 
1105   switch (ptype) { /* local (m,n) and global (M,N) sizes of C for each product type */
1106   case MATPRODUCT_AB:
1107     m = A->rmap->n;
1108     n = B->cmap->n;
1109     M = A->rmap->N;
1110     N = B->cmap->N;
1111     break;
1112   case MATPRODUCT_AtB:
1113     m = A->cmap->n;
1114     n = B->cmap->n;
1115     M = A->cmap->N;
1116     N = B->cmap->N;
1117     break;
1118   case MATPRODUCT_PtAP:
1119     m = B->cmap->n;
1120     n = B->cmap->n;
1121     M = B->cmap->N;
1122     N = B->cmap->N;
1123     break; /* BtAB */
1124   default:
1125     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1126   }
1127 
1128   PetscCall(MatSetSizes(C, m, n, M, N));
1129   PetscCall(PetscLayoutSetUp(C->rmap));
1130   PetscCall(PetscLayoutSetUp(C->cmap));
1131   PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); /* C gets the same type as A */
1132 
1133   mmdata           = new MatProductData_MPIAIJKokkos();
1134   mmdata->reusesym = product->api_user;
1135 
1136   if (ptype == MATPRODUCT_AB) {
1137     mmdata->mmAB = new MatMatStruct_AB();
1138     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1139     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1140   } else if (ptype == MATPRODUCT_AtB) {
1141     mmdata->mmAtB = new MatMatStruct_AtB();
1142     auto atb      = mmdata->mmAtB;
1143     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1144     PetscCall(ISGetIndices(glob, &garray));
1145     PetscCall(ISGetSize(glob, &sz));
1146     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz)); /* wrap garray as a host view and mirror it to the default memory space */
1147     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1148     PetscCall(ISRestoreIndices(glob, &garray)); /* restored only after l2g's last use above */
1149     PetscCall(ISDestroy(&glob));
1150     mm = static_cast<MatMatStruct *>(atb);
1151   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1152     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1153     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1154     auto ab       = mmdata->mmAB;
1155     auto atb      = mmdata->mmAtB;
1156     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1157     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1158     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1159     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1160     mm = static_cast<MatMatStruct *>(atb);
1161   }
1162   /* Split the C_global into petsc A, B format */
1163   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1164   C->product->data       = mmdata; /* hand the stashed data over to C for the numeric phase */
1165   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1166   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1167   PetscFunctionReturn(0);
1168 }
1169 
1170 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1171 {
1172   Mat_Product *product = mat->product;
1173   PetscBool    match   = PETSC_FALSE;
1174   PetscBool    usecpu  = PETSC_FALSE;
1175 
1176   PetscFunctionBegin;
1177   MatCheckProduct(mat, 1);
1178   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1179   if (match) { /* we can always fallback to the CPU if requested */
1180     switch (product->type) {
1181     case MATPRODUCT_AB:
1182       if (product->api_user) {
1183         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1184         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1185         PetscOptionsEnd();
1186       } else {
1187         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1188         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1189         PetscOptionsEnd();
1190       }
1191       break;
1192     case MATPRODUCT_AtB:
1193       if (product->api_user) {
1194         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1195         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1196         PetscOptionsEnd();
1197       } else {
1198         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1199         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1200         PetscOptionsEnd();
1201       }
1202       break;
1203     case MATPRODUCT_PtAP:
1204       if (product->api_user) {
1205         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1206         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1207         PetscOptionsEnd();
1208       } else {
1209         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1210         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1211         PetscOptionsEnd();
1212       }
1213       break;
1214     default:
1215       break;
1216     }
1217     match = (PetscBool)!usecpu;
1218   }
1219   if (match) {
1220     switch (product->type) {
1221     case MATPRODUCT_AB:
1222     case MATPRODUCT_AtB:
1223     case MATPRODUCT_PtAP:
1224       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1225       break;
1226     default:
1227       break;
1228     }
1229   }
1230   /* fallback to MPIAIJ ops */
1231   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1232   PetscFunctionReturn(0);
1233 }
1234 
1235 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1236 {
1237   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1238   Mat_MPIAIJKokkos *mpikok;
1239 
1240   PetscFunctionBegin;
1241   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
1242   mat->preallocated = PETSC_TRUE;
1243   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1244   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1245   PetscCall(MatZeroEntries(mat));
1246   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1247   delete mpikok;
1248   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1249   PetscFunctionReturn(0);
1250 }
1251 
1252 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1253 {
1254   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1255   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1256   Mat                         A = mpiaij->A, B = mpiaij->B;
1257   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1258   MatScalarKokkosView         Aa, Ba;
1259   MatScalarKokkosView         v1;
1260   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1261   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1262   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1263   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1264   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1265   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1266   PetscMemType                memtype;
1267 
1268   PetscFunctionBegin;
1269   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1270   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1271     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1272   } else {
1273     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1274   }
1275 
1276   if (imode == INSERT_VALUES) {
1277     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1278     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1279   } else {
1280     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1281     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1282   }
1283 
1284   /* Pack entries to be sent to remote */
1285   Kokkos::parallel_for(
1286     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1287 
1288   /* Send remote entries to their owner and overlap the communication with local computation */
1289   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1290   /* Add local entries to A and B in one kernel */
1291   Kokkos::parallel_for(
1292     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1293       PetscScalar sum = 0.0;
1294       if (i < Annz) {
1295         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1296         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1297       } else {
1298         i -= Annz;
1299         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1300         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1301       }
1302     });
1303   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1304 
1305   /* Add received remote entries to A and B in one kernel */
1306   Kokkos::parallel_for(
1307     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1308       if (i < Annz2) {
1309         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1310       } else {
1311         i -= Annz2;
1312         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1313       }
1314     });
1315 
1316   if (imode == INSERT_VALUES) {
1317     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1318     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1319   } else {
1320     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1321     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1322   }
1323   PetscFunctionReturn(0);
1324 }
1325 
1326 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1327 {
1328   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1329 
1330   PetscFunctionBegin;
1331   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1332   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1333   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1334   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1335   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1336   PetscCall(MatDestroy_MPIAIJ(A));
1337   PetscFunctionReturn(0);
1338 }
1339 
1340 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1341 {
1342   Mat         B;
1343   Mat_MPIAIJ *a;
1344 
1345   PetscFunctionBegin;
1346   if (reuse == MAT_INITIAL_MATRIX) {
1347     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1348   } else if (reuse == MAT_REUSE_MATRIX) {
1349     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1350   }
1351   B = *newmat;
1352 
1353   B->boundtocpu = PETSC_FALSE;
1354   PetscCall(PetscFree(B->defaultvectype));
1355   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1356   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1357 
1358   a = static_cast<Mat_MPIAIJ *>(A->data);
1359   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1360   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1361   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1362 
1363   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1364   B->ops->mult                  = MatMult_MPIAIJKokkos;
1365   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1366   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1367   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1368   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1369 
1370   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1371   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1372   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1373   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1374   PetscFunctionReturn(0);
1375 }
1376 /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos

   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
1380 
1381    Options Database Keys:
1382 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1383 
1384   Level: beginner
1385 
1386 .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1387 M*/
1388 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1389 {
1390   PetscFunctionBegin;
1391   PetscCall(PetscKokkosInitializeCheck());
1392   PetscCall(MatCreate_MPIAIJ(A));
1393   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1394   PetscFunctionReturn(0);
1395 }
1396 
1397 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.
1404 
1405    Collective
1406 
1407    Input Parameters:
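+  comm - MPI communicator
.  m - number of local rows
.  n - number of local columns
.  M - number of global rows
.  N - number of global columns
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows);
          not used when comm has a single process
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           (possibly different for each row) or NULL; not used when comm has a single process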
1414 
1415    Output Parameter:
1416 .  A - the matrix
1417 
1418    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1419    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1420    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1421 
1422    Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored
1424 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
   storage.  That is, the stored row and column indices can begin at
   either one (as in Fortran) or zero.  See the users' manual for details.
1428 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
   allocation.  For large problems you MUST preallocate memory or you
   will get TERRIBLE performance, see the users' manual chapter on matrices.
1433 
1434    By default, this format uses inodes (identical nodes) when possible, to
1435    improve numerical efficiency of matrix-vector products and solves. We
1436    search for consecutive rows with the same nonzero structure, thereby
1437    reusing matrix information to achieve increased efficiency.
1438 
1439    Level: intermediate
1440 
.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1442 @*/
1443 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1444 {
1445   PetscMPIInt size;
1446 
1447   PetscFunctionBegin;
1448   PetscCall(MatCreate(comm, A));
1449   PetscCall(MatSetSizes(*A, m, n, M, N));
1450   PetscCallMPI(MPI_Comm_size(comm, &size));
1451   if (size > 1) {
1452     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1453     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1454   } else {
1455     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1456     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1457   }
1458   PetscFunctionReturn(0);
1459 }
1460 
1461 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1462 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1463 {
1464   PetscMPIInt                size, rank;
1465   MPI_Comm                   comm;
1466   PetscSplitCSRDataStructure d_mat = NULL;
1467 
1468   PetscFunctionBegin;
1469   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1470   PetscCallMPI(MPI_Comm_size(comm, &size));
1471   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1472   if (size == 1) {
1473     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1474     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1475   } else {
1476     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1477     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1478     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1479     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1480     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1481   }
1482   // act like MatSetValues because not called on host
1483   if (A->assembled) {
1484     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1485     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1486   } else {
1487     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1488   }
1489   if (!d_mat) {
1490     struct _n_SplitCSRMat h_mat; /* host container */
1491     Mat_SeqAIJKokkos     *aijkokA;
1492     Mat_SeqAIJ           *jaca;
1493     PetscInt              n = A->rmap->n, nnz;
1494     Mat                   Amat;
1495     PetscInt             *colmap;
1496 
1497     /* create and copy h_mat */
1498     h_mat.M = A->cmap->N; // use for debug build
1499     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1500     if (size == 1) {
1501       Amat            = A;
1502       jaca            = (Mat_SeqAIJ *)A->data;
1503       h_mat.rstart    = 0;
1504       h_mat.rend      = A->rmap->n;
1505       h_mat.cstart    = 0;
1506       h_mat.cend      = A->cmap->n;
1507       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1508       h_mat.offdiag.a                   = NULL;
1509       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1510     } else {
1511       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1512       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1513       PetscInt          ii;
1514       Mat_SeqAIJKokkos *aijkokB;
1515 
1516       Amat    = aij->A;
1517       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1518       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1519       jaca    = (Mat_SeqAIJ *)aij->A->data;
1520       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1521       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1522       aij->donotstash          = PETSC_TRUE;
1523       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1524       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1525       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1526       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1527       // allocate B copy data
1528       h_mat.rstart = A->rmap->rstart;
1529       h_mat.rend   = A->rmap->rend;
1530       h_mat.cstart = A->cmap->rstart;
1531       h_mat.cend   = A->cmap->rend;
1532       nnz          = jacb->i[n];
1533       if (jacb->compressedrow.use) {
1534         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1535         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1536         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1537         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1538       } else {
1539         h_mat.offdiag.i = aijkokB->i_device_data();
1540       }
1541       h_mat.offdiag.j = aijkokB->j_device_data();
1542       h_mat.offdiag.a = aijkokB->a_device_data();
1543       {
1544         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1545         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1546         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1547         h_mat.colmap = aijkokB->colmap_d.data();
1548         PetscCall(PetscFree(colmap));
1549       }
1550       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1551       h_mat.offdiag.n                 = n;
1552     }
1553     // allocate A copy data
1554     nnz                          = jaca->i[n];
1555     h_mat.diag.n                 = n;
1556     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1557     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1558     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
1559     h_mat.diag.i = aijkokA->i_device_data();
1560     h_mat.diag.j = aijkokA->j_device_data();
1561     h_mat.diag.a = aijkokA->a_device_data();
1562     // copy pointers and metdata to device
1563     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1564     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1565     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1566   }
1567   *B           = d_mat;       // return it, set it in Mat, and set it up
1568   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1569   PetscFunctionReturn(0);
1570 }
1571 
1572 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1573 {
1574   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1575 
1576   PetscFunctionBegin;
1577   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1578   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1579   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1580   else *mask = "PETSC_OFFLOAD_BOTH";
1581   PetscFunctionReturn(0);
1582 }
1583 
1584 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1585 {
1586   PetscMPIInt size;
1587   Mat         Ad, Ao;
1588   const char *amask, *bmask;
1589 
1590   PetscFunctionBegin;
1591   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1592 
1593   if (size == 1) {
1594     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1595     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1596   } else {
1597     Ad = ((Mat_MPIAIJ *)A->data)->A;
1598     Ao = ((Mat_MPIAIJ *)A->data)->B;
1599     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1600     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1601     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1602   }
1603   PetscFunctionReturn(0);
1604 }
1605