xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 43cdf1ebbcbfa05bee08e48007ef1bae3f20f4e9)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
9 {
10   Mat_SeqAIJKokkos *aijkok;
11   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
12 
13   PetscFunctionBegin;
14   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
15   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
16      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
17    */
18   if (mode == MAT_FINAL_ASSEMBLY) {
19     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
20     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
21     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
22   }
23   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
24   if (aijkok && aijkok->device_mat_d.data()) {
25     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
26   }
27 
28   PetscFunctionReturn(0);
29 }
30 
31 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
32 {
33   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
34 
35   PetscFunctionBegin;
36   PetscCall(PetscLayoutSetUp(mat->rmap));
37   PetscCall(PetscLayoutSetUp(mat->cmap));
38 #if defined(PETSC_USE_DEBUG)
39   if (d_nnz) {
40     PetscInt i;
41     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
42   }
43   if (o_nnz) {
44     PetscInt i;
45     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
46   }
47 #endif
48 #if defined(PETSC_USE_CTABLE)
49   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
50 #else
51   PetscCall(PetscFree(mpiaij->colmap));
52 #endif
53   PetscCall(PetscFree(mpiaij->garray));
54   PetscCall(VecDestroy(&mpiaij->lvec));
55   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
56   /* Because the B will have been resized we simply destroy it and create a new one each time */
57   PetscCall(MatDestroy(&mpiaij->B));
58 
59   if (!mpiaij->A) {
60     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
61     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
62   }
63   if (!mpiaij->B) {
64     PetscMPIInt size;
65     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
66     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
67     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
68   }
69   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
70   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
71   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
72   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
73   mat->preallocated = PETSC_TRUE;
74   PetscFunctionReturn(0);
75 }
76 
77 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
78 {
79   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
80   PetscInt    nt;
81 
82   PetscFunctionBegin;
83   PetscCall(VecGetLocalSize(xx, &nt));
84   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
85   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
86   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
87   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
88   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
89   PetscFunctionReturn(0);
90 }
91 
92 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
93 {
94   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
95   PetscInt    nt;
96 
97   PetscFunctionBegin;
98   PetscCall(VecGetLocalSize(xx, &nt));
99   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
100   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
101   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
102   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
103   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
104   PetscFunctionReturn(0);
105 }
106 
107 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
108 {
109   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
110   PetscInt    nt;
111 
112   PetscFunctionBegin;
113   PetscCall(VecGetLocalSize(xx, &nt));
114   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
115   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
116   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
117   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
118   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119   PetscFunctionReturn(0);
120 }
121 
122 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
123    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
124    C still uses local column ids. Their corresponding global column ids are returned in glob.
125 */
126 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
127 {
128   Mat             Ad, Ao;
129   const PetscInt *cmap;
130 
131   PetscFunctionBegin;
132   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
133   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
134   if (glob) {
135     PetscInt cst, i, dn, on, *gidx;
136     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
137     PetscCall(MatGetLocalSize(Ao, NULL, &on));
138     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
139     PetscCall(PetscMalloc1(dn + on, &gidx));
140     for (i = 0; i < dn; i++) gidx[i] = cst + i;
141     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
142     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
143   }
144   PetscFunctionReturn(0);
145 }
146 
147 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
148 struct MatMatStruct {
149   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
150   PetscSF             sf;      /* SF to send/recv matrix entries */
151   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
152   Mat                 C1, C2, B_local;
153   KokkosCsrMatrix     C1_global, C2_global, C_global;
154   KernelHandle        kh;
155   MatMatStruct()
156   {
157     C1 = C2 = B_local = NULL;
158     sf                = NULL;
159   }
160 
161   ~MatMatStruct()
162   {
163     MatDestroy(&C1);
164     MatDestroy(&C2);
165     MatDestroy(&B_local);
166     PetscSFDestroy(&sf);
167     kh.destroy_spadd_handle();
168   }
169 };
170 
171 struct MatMatStruct_AB : public MatMatStruct {
172   MatColIdxKokkosView rows;
173   MatRowMapKokkosView rowoffset;
174   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
175 
176   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
177   ~MatMatStruct_AB()
178   {
179     MatDestroy(&B_other);
180     MatDestroy(&C_petsc);
181   }
182 };
183 
184 struct MatMatStruct_AtB : public MatMatStruct {
185   MatRowMapKokkosView srcrowoffset, dstrowoffset;
186 };
187 
188 struct MatProductData_MPIAIJKokkos {
189   MatMatStruct_AB  *mmAB;
190   MatMatStruct_AtB *mmAtB;
191   PetscBool         reusesym;
192 
193   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
194   ~MatProductData_MPIAIJKokkos()
195   {
196     delete mmAB;
197     delete mmAtB;
198   }
199 };
200 
201 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202 {
203   PetscFunctionBegin;
204   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
205   PetscFunctionReturn(0);
206 }
207 
208 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
209 
210    Input Parameters:
211 +  A       - the MATSEQAIJKOKKOS matrix
212 .  N       - new column size for the returned Kokkos matrix
213 -  l2g     - a map that maps old col ids to new col ids
214 
215    Output Parameters:
216 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
217  */
218 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
219 {
220   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
221   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
222 
223   PetscFunctionBegin;
224   PetscCallCXX(Kokkos::parallel_for(
225     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
226   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
227   PetscFunctionReturn(0);
228 }
229 
230 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
231    It is similar to MatCreateMPIAIJWithSplitArrays.
232 
233   Input Parameters:
234 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
235 .  A     - the diag matrix using local col ids
236 -  B     - the offdiag matrix using global col ids
237 
238   Output Parameters:
239 .  mat   - the updated MATMPIAIJKOKKOS matrix
240 */
241 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
242 {
243   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
244   PetscInt          m, n, M, N, Am, An, Bm, Bn;
245   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
246 
247   PetscFunctionBegin;
248   PetscCall(MatGetSize(mat, &M, &N));
249   PetscCall(MatGetLocalSize(mat, &m, &n));
250   PetscCall(MatGetLocalSize(A, &Am, &An));
251   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
252 
253   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
254   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
255   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
256   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
257   mpiaij->A = A;
258   mpiaij->B = B;
259 
260   mat->preallocated     = PETSC_TRUE;
261   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
262 
263   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
264   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
265   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
266     also gets mpiaij->B compacted, with its col ids and size reduced
267   */
268   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
269   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
270   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
271 
272   /* Update bkok with new local col ids (stored on host) and size */
273   bkok->j_dual.modify_host();
274   bkok->j_dual.sync_device();
275   bkok->SetColSize(mpiaij->B->cmap->n);
276   PetscFunctionReturn(0);
277 }
278 
279 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
280 
281    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
282    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
283    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
284    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
285 
286    Collective on comm of ownerSF
287 
288    Input Parameters:
289 +   B       - the SEQAIJKOKKOS matrix, using local col ids
290 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
291 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
292 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
293 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
294 
295    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
296 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
297 .   abuf      - buffer for sending matrix values
298 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
299                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
300 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
301 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
302 */
303 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
304 {
305   Mat_SeqAIJKokkos *bkok, *ckok;
306 
307   PetscFunctionBegin;
308   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
309   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
310 
311   if (reuse == MAT_REUSE_MATRIX) {
312     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
313 
314     const auto &Ba = bkok->a_dual.view_device();
315     const auto &Bi = bkok->i_dual.view_device();
316     const auto &Ca = ckok->a_dual.view_device();
317 
318     /* Copy Ba to abuf */
319     Kokkos::parallel_for(
320       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
321         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
322         PetscInt r    = rows(i);
323         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
324         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
325       });
326 
327     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
328     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
329     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
330     ckok->a_dual.modify_device();
331   } else if (reuse == MAT_INITIAL_MATRIX) {
332     MPI_Comm    comm;
333     PetscMPIInt tag;
334     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
335 
336     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
337     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
338     Cm = nleaves; /* row size of C */
339     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
340 
341     /* Get row lens (nz) of B's rows for later fast query */
342     PetscInt       *Browlens;
343     const PetscInt *tmp = bkok->i_host_data();
344     PetscCall(PetscMalloc1(nroots, &Browlens));
345     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
346 
347     /* By ownerSF, each proc gets lens of rows of C */
348     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
349     Ci_h    = Ci.view_host().data();
350     Ci_h[0] = 0;
351     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
352     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
353     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
354     Cnnz = Ci_h[Cm];
355     Ci.modify_host();
356     Ci.sync_device();
357 
358     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
359     MatColIdxKokkosDualView Cj("j", Cnnz);
360     MatScalarKokkosDualView Ca("a", Cnnz);
361 
362     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
363     const PetscMPIInt *iranks, *ranks;
364     const PetscInt    *ioffset, *irootloc, *roffset;
365     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
366     MPI_Request       *reqs;
367 
368     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
369     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
370 
371     /* figure out offsets at the send buffer, to build the SF
372       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
373       rowptr[] - stores offsets for data of each row in abuf
374 
375       rdisp[]  - to receive sdisp[]
376     */
377     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
378     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
379     rowptr = rowptr_h.data();
380 
381     sdisp[0]  = 0;
382     rowptr[0] = 0;
383     for (i = 0; i < niranks; i++) { /* for each receiver */
384       PetscInt len, nz = 0;
385       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
386         len           = Browlens[irootloc[j]];
387         rowptr[j + 1] = rowptr[j] + len;
388         nz += len;
389       }
390       sdisp[i + 1] = sdisp[i] + nz;
391     }
392     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
393     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
394     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
395     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
396 
397     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
398     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
399     PetscSFNode *iremote;
400     PetscCall(PetscMalloc1(nleaves2, &iremote));
401     for (i = 0; i < nranks; i++) { /* for each sender */
402       k = 0;
403       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
404         iremote[j].rank  = ranks[i];
405         iremote[j].index = rdisp[i] + k;
406         k++;
407       }
408     }
409     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
410     PetscCall(PetscSFCreate(comm, &bcastSF));
411     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
412 
413     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
414       from local to global. Then use bcastSF to fill Ca, Cj.
415     */
416     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
417     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
418     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
419 
420     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
421 
422     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
423     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
424 
425     const auto &Ba = bkok->a_dual.view_device();
426     const auto &Bi = bkok->i_dual.view_device();
427     const auto &Bj = bkok->j_dual.view_device();
428 
429     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
430     Kokkos::parallel_for(
431       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
432         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
433         PetscInt r    = rows(i);
434         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
435         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
436           abuf(base + k) = Ba(Bi(r) + k);
437           jbuf(base + k) = l2g(Bj(Bi(r) + k));
438         });
439       });
440 
441     /* Send abuf & jbuf to fill Ca, Cj */
442     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
443     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
444     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
445     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
446     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
447     Cj.sync_host();
448     Ca.modify_device();
449 
450     /* Construct C with Ca, Ci, Cj */
451     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
452     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
453     PetscCall(PetscFree3(sdisp, rdisp, reqs));
454     PetscCall(PetscFree(Browlens));
455   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
456   PetscFunctionReturn(0);
457 }
458 
459 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
460 
461   It is the reverse of MatSeqAIJKokkosBcast in some sense.
462 
463   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
464   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
465   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
466 
467   Input Parameters:
468 +  A        - the SEQAIJKOKKOS matrix to be reduced
469 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
470 .  local    - true if A uses local col ids; false if A is already in global col ids.
471 .  N        - if local, N is A's global col size
472 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
473 -  ownerSF  - the SF specifies ownership (root) of rows in A
474 
475   Output Parameters:
476 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
477 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
478 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
479 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
480                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
481 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
482 
483    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide an opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
484  */
485 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
486 {
487   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
488   const PetscInt   *Ai;
489   Mat_SeqAIJKokkos *akok;
490 
491   PetscFunctionBegin;
492   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
493   PetscCall(MatGetSize(A, &Am, &An));
494   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
495   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
496   Annz = Ai[Am];
497 
498   if (reuse == MAT_REUSE_MATRIX) {
499     /* Send Aa to abuf */
500     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
501     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
502 
503     /* Copy abuf to Ca */
504     const MatScalarKokkosView &Ca = C.values;
505     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
506     Kokkos::parallel_for(
507       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
508         PetscInt i   = t.league_rank();
509         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
510         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
511         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
512       });
513   } else if (reuse == MAT_INITIAL_MATRIX) {
514     MPI_Comm     comm;
515     MPI_Request *reqs;
516     PetscMPIInt  tag;
517     PetscInt     Cm;
518 
519     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
520     PetscCall(PetscCommGetNewTag(comm, &tag));
521 
522     PetscInt           niranks, nranks, nroots, nleaves;
523     const PetscMPIInt *iranks, *ranks;
524     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
525     PetscCall(PetscSFSetUp(ownerSF));
526     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
527     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
528     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
529     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
530     Cm    = nroots;
531     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
532 
533     /* Tell owners how long each row I will send */
534     PetscInt               *srowlens;                              /* send buf of row lens */
535     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
536     PetscInt               *rrowlens = rrowlens_h.data();
537 
538     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
539     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
540     rrowlens[0] = 0;
541     rrowlens++; /* shift the pointer to make the following expression more readable */
542     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
543     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
544     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
545 
546     /* Owner builds Ci on host by histogramming rrowlens[] */
547     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
548     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
549     MatRowMapType *Ci_ptr = Ci_h.data();
550 
551     for (i = 0; i < nrows; i++) {
552       r = rows[i]; /* local row id of i-th received row */
553 #if defined(PETSC_USE_DEBUG)
554       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
555 #endif
556       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
557     }
558     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
559     Cnnz = Ci_ptr[Cm];
560 
561     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
562     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
563     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
564     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
565 
566     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
567     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
568       r                    = rows[i];                   /* row id in C */
569       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
570       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
571     }
572     PetscCall(PetscFree(currowlens));
573 
574     rrowlens--;
575     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
576     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
577     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
578 
579     /* Build the reduceSF, which performs buffer to buffer send/recv */
580     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
581     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
582     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
583     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
584     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
585     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
586 
587     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
588     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
589     PetscSFNode *iremote;
590     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
591     for (i = 0; i < nranks; i++) {
592       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
593       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
594       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
595       for (PetscInt k = 0; k < nz; k++) {
596         iremote[leafbase + k].rank  = ranks[i];
597         iremote[leafbase + k].index = rootbase + k;
598       }
599     }
600     PetscCall(PetscSFCreate(comm, &reduceSF));
601     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
602     PetscCall(PetscFree2(sdisp, rdisp));
603 
604     /* Reduce Aa, Ajg to abuf and jbuf */
605 
606     /* If A uses local col ids, convert them to global ones before sending */
607     MatColIdxKokkosView Ajg;
608     if (local) {
609       Ajg                           = MatColIdxKokkosView("j", Annz);
610       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
611       Kokkos::parallel_for(
612         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
613     } else {
614       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
615     }
616 
617     MatColIdxKokkosView jbuf("jbuf", Cnnz);
618     abuf = MatScalarKokkosView("abuf", Cnnz);
619     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
620     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
621     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
622     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
623 
624     /* Copy data from abuf, jbuf to Ca, Cj */
625     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
626     MatColIdxKokkosView Cj("j", Cnnz);
627     MatScalarKokkosView Ca("a", Cnnz);
628 
629     Kokkos::parallel_for(
630       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
631         PetscInt i   = t.league_rank();
632         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
633         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
634         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
635           Ca(dst + k) = abuf(src + k);
636           Cj(dst + k) = jbuf(src + k);
637         });
638       });
639 
640     /* Build C with Ca, Ci, Cj */
641     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
642     PetscCall(PetscFree2(srowlens, reqs));
643   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
644   PetscFunctionReturn(0);
645 }
646 
647 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
648 
649   Input Parameters:
650 +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
651 .  reuse    - indicate whether the matrix has called this function before
652 .  csrmat   - the KokkosCsrMatrix, of size m,N
653 -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
654               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
655               entry is 5, then Cdstart[i] = 3.
656 
657   Output Parameters:
658 +  C        - the updated `MATMPIAIJKOKKOS` matrix
659 -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
660 
661   Note:
662    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
663 
664 .seealso: `MATMPIAIJKOKKOS`
665  */
666 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
667 {
668   const MatScalarKokkosView      &Ca = csrmat.values;
669   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
670   PetscInt                        m, n, N;
671 
672   PetscFunctionBegin;
673   PetscCall(MatGetLocalSize(C, &m, &n));
674   PetscCall(MatGetSize(C, NULL, &N));
675 
676   if (reuse == MAT_REUSE_MATRIX) {
677     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
678     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
679     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
680     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
681     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
682 
683     /* Fill 'a' of Cd and Co on device */
684     Kokkos::parallel_for(
685       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
686         PetscInt i       = t.league_rank();     /* row i */
687         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
688         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
689         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
690         PetscInt cdend   = cdstart + cdlen;
691         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
692         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
693           if (k < cdstart) { /* k in [0, cdstart) */
694             Coa(Coi(i) + k) = Ca(Ci(i) + k);
695           } else if (k < cdend) { /* k in [cdstart, cdend) */
696             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
697           } else { /* k in [cdend, clen) */
698             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
699           }
700         });
701       });
702 
703     akok->a_dual.modify_device();
704     bkok->a_dual.modify_device();
705   } else if (reuse == MAT_INITIAL_MATRIX) {
706     Mat                        Cd, Co;
707     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
708     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
709     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
710     PetscInt                   cstart, cend;
711 
712     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
713        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
714        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
715        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
716        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
717      */
718     Cdstart = MatRowMapKokkosView("Cdstart", m);
719     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
720 
721     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
722       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
723      */
724     Kokkos::parallel_for(
725       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
726         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
727                                                    PetscInt i = t.league_rank(); /* row i */
728                                                    PetscInt j, first, count, step;
729 
730                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
731                                                      Cdi(0) = 0;
732                                                      Coi(0) = 0;
733                                                    }
734 
735                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
736           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
737         */
738                                                    count = Ci(i + 1) - Ci(i);
739                                                    first = Ci(i);
740                                                    while (count > 0) {
741                                                      j    = first;
742                                                      step = count / 2;
743                                                      j += step;
744                                                      if (Cj(j) < cstart) {
745                                                        first = ++j;
746                                                        count -= step + 1;
747                                                      } else count = step;
748                                                    }
749                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
750 
751                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
752                                                    count = Ci(i + 1) - first;
753                                                    while (count > 0) {
754                                                      j    = first;
755                                                      step = count / 2;
756                                                      j += step;
757                                                      if (Cj(j) < cend) {
758                                                        first = ++j;
759                                                        count -= step + 1;
760                                                      } else count = step;
761                                                    }
762                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
763                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
764         });
765       });
766 
767     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
768     Kokkos::parallel_scan(
769       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
770         update += Cdi(i);
771         if (final) Cdi(i) = update;
772       });
773     Kokkos::parallel_scan(
774       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
775         update += Coi(i);
776         if (final) Coi(i) = update;
777       });
778 
779     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
780        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
781     */
782     Cdi_dual.modify_device();
783     Coi_dual.modify_device();
784     Cdi_dual.sync_host();
785     Coi_dual.sync_host();
786     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
787     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
788 
789     /* With nnz, allocate a, j for Cd and Co */
790     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
791     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
792 
793     /* Fill a, j of Cd and Co on device */
794     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
795     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
796 
797     Kokkos::parallel_for(
798       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
799         PetscInt i       = t.league_rank();     /* row i */
800         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
801         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
802         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
803         PetscInt cdend   = cdstart + cdlen;
804         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
805         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
806           if (k < cdstart) { /* k in [0, cdstart) */
807             Coa(Coi(i) + k) = Ca(Ci(i) + k);
808             Coj(Coi(i) + k) = Cj(Ci(i) + k);
809           } else if (k < cdend) { /* k in [cdstart, cdend) */
810             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
811             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
812           } else {                                                /* k in [cdend, clen) */
813             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
814             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
815           }
816         });
817       });
818 
819     Cdj_dual.modify_device();
820     Cda_dual.modify_device();
821     Coj_dual.modify_device();
822     Coa_dual.modify_device();
823     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
824     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
825     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
826     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
827     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
828     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
829   }
830   PetscFunctionReturn(0);
831 }
832 
833 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
834 
835   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
836 
837   Input Parameters:
838 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
839 .  reuse    - indicate whether the matrix has called this function before
840 .  csrmat   - the KokkosCsrMatrix, of size m,N
841 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
842               entry of the diag block of C in csrmat's j array.
843 
844   Output Parameters:
845 +  C        - the updated MATMPIAIJKOKKOS matrix
846 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
847 
848   Note:
849   the input matrix's col ids and col size will be changed.
850 */
851 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
852 {
853   Mat_SeqAIJKokkos      *ckok;
854   ISLocalToGlobalMapping l2gmap;
855   const PetscInt        *garray;
856   PetscInt               sz;
857 
858   PetscFunctionBegin;
859   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
860   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
861   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
862   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
863   ckok->j_dual.sync_device();
864   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
865 
866   /* Build l2g -- the local to global mapping of C's cols */
867   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
868   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
869   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
870 
871   ConstMatColIdxKokkosViewHost tmp(garray, sz);
872   l2g = MatColIdxKokkosView("l2g", sz);
873   Kokkos::deep_copy(l2g, tmp);
874 
875   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
876   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
877   PetscFunctionReturn(0);
878 }
879 
880 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
881 
882   Input Parameters:
883 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
884 .  A        - an MPIAIJKOKKOS matrix
885 .  B        - an MPIAIJKOKKOS matrix
886 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
887 
888   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
889 */
890 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
891 {
892   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
893   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
894   IS                       glob = NULL;
895   const PetscInt          *garray;
896   PetscInt                 N = B->cmap->N, sz;
897   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
898   MatColIdxKokkosView      l2g2;
899   Mat                      C1, C2; /* intermediate matrices */
900 
901   PetscFunctionBegin;
902   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
903   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
904   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
905   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
906   PetscCall(MatProductSetFill(C1, product->fill));
907   C1->product->api_user = product->api_user;
908   PetscCall(MatProductSetFromOptions(C1));
909   PetscUseTypeMethod(C1, productsymbolic);
910 
911   PetscCall(ISGetIndices(glob, &garray));
912   PetscCall(ISGetSize(glob, &sz));
913   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
914   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
915   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
916 
917   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
918   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
919 
920   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
921   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
922   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
923   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
924   PetscCall(MatProductSetFill(C2, product->fill));
925   C2->product->api_user = product->api_user;
926   PetscCall(MatProductSetFromOptions(C2));
927   PetscUseTypeMethod(C2, productsymbolic);
928   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
929 
930   /* C = C1 + C2.  We actually use their global col ids versions in adding */
931   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
932   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
933   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
934   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
935 
936   mm->C1 = C1;
937   mm->C2 = C2;
938   PetscCall(ISRestoreIndices(glob, &garray));
939   PetscCall(ISDestroy(&glob));
940   PetscFunctionReturn(0);
941 }
942 
943 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
944 
945   Input Parameters:
946 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
947 .  A        - an MPIAIJKOKKOS matrix
948 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
949 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
950 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
951 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
952 -  mm       - a struct used to stash intermediate data in AtB
953 
954   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
955 */
956 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
957 {
958   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
959   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
960   Mat         C1, C2;               /* intermediate matrices */
961 
962   PetscFunctionBegin;
963   /* C1 = Ad^t * B */
964   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
965   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
966   PetscCall(MatProductSetFill(C1, product->fill));
967   C1->product->api_user = product->api_user;
968   PetscCall(MatProductSetFromOptions(C1));
969   PetscUseTypeMethod(C1, productsymbolic);
970 
971   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
972   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
973 
974   /* C2 = Ao^t * B */
975   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
976   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
977   PetscCall(MatProductSetFill(C2, product->fill));
978   C2->product->api_user = product->api_user;
979   PetscCall(MatProductSetFromOptions(C2));
980   PetscUseTypeMethod(C2, productsymbolic);
981 
982   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
983 
984   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
985   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
986   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
987   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
988   mm->C1 = C1;
989   mm->C2 = C2;
990   PetscFunctionReturn(0);
991 }
992 
993 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
994 {
995   Mat_Product                 *product = C->product;
996   MatProductType               ptype;
997   MatProductData_MPIAIJKokkos *mmdata;
998   MatMatStruct                *mm = NULL;
999   MatMatStruct_AB             *ab;
1000   MatMatStruct_AtB            *atb;
1001   Mat                          A, B, Ad, Ao, Bd, Bo;
1002   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1003 
1004   PetscFunctionBegin;
1005   MatCheckProduct(C, 1);
1006   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1007   ptype  = product->type;
1008   A      = product->A;
1009   B      = product->B;
1010   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1011   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1012   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1013   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1014 
1015   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1016     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1017     ab               = mmdata->mmAB;
1018     atb              = mmdata->mmAtB;
1019     if (ab) {
1020       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1021       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1022     }
1023     if (atb) {
1024       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1025       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1026     }
1027     PetscFunctionReturn(0);
1028   }
1029 
1030   if (ptype == MATPRODUCT_AB) {
1031     ab = mmdata->mmAB;
1032     /* C1 = Ad * B_local */
1033     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1034     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1035     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1036     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1037     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1038     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1039     /* C2 = Ao * B_other */
1040     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1041     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1042     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1043     /* C = C1_global + C2_global */
1044     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1045     mm = static_cast<MatMatStruct *>(ab);
1046   } else if (ptype == MATPRODUCT_AtB) {
1047     atb = mmdata->mmAtB;
1048     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1049     /* C1 = Ad^t * B_local */
1050     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1051     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1052     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1053     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1054 
1055     /* C2 = Ao^t * B_local */
1056     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1057     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1058     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1059     /* Form C2_global */
1060     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1061     /* C = C1_global + C2_global */
1062     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1063     mm = static_cast<MatMatStruct *>(atb);
1064   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1065     ab = mmdata->mmAB;
1066     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1067 
1068     /* ab->C1 = Ad * B_local */
1069     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1070     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1071     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1072     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1073     /* ab->C2 = Ao * B_other */
1074     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1075     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1076     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1077 
1078     /* atb->C1 = Bd^t * ab->C_petsc */
1079     atb = mmdata->mmAtB;
1080     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1081     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1082     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1083     /* atb->C2 = Bo^t * ab->C_petsc */
1084     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1085     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1086     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1087     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1088     mm = static_cast<MatMatStruct *>(atb);
1089   }
1090   /* Split C_global to form C */
1091   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1092   PetscFunctionReturn(0);
1093 }
1094 
1095 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1096 {
1097   Mat                          A, B;
1098   Mat_Product                 *product = C->product;
1099   MatProductType               ptype;
1100   MatProductData_MPIAIJKokkos *mmdata;
1101   MatMatStruct                *mm   = NULL;
1102   IS                           glob = NULL;
1103   const PetscInt              *garray;
1104   PetscInt                     m, n, M, N, sz;
1105   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1106 
1107   PetscFunctionBegin;
1108   MatCheckProduct(C, 1);
1109   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1110   ptype = product->type;
1111   A     = product->A;
1112   B     = product->B;
1113 
1114   switch (ptype) {
1115   case MATPRODUCT_AB:
1116     m = A->rmap->n;
1117     n = B->cmap->n;
1118     M = A->rmap->N;
1119     N = B->cmap->N;
1120     break;
1121   case MATPRODUCT_AtB:
1122     m = A->cmap->n;
1123     n = B->cmap->n;
1124     M = A->cmap->N;
1125     N = B->cmap->N;
1126     break;
1127   case MATPRODUCT_PtAP:
1128     m = B->cmap->n;
1129     n = B->cmap->n;
1130     M = B->cmap->N;
1131     N = B->cmap->N;
1132     break; /* BtAB */
1133   default:
1134     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1135   }
1136 
1137   PetscCall(MatSetSizes(C, m, n, M, N));
1138   PetscCall(PetscLayoutSetUp(C->rmap));
1139   PetscCall(PetscLayoutSetUp(C->cmap));
1140   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1141 
1142   mmdata           = new MatProductData_MPIAIJKokkos();
1143   mmdata->reusesym = product->api_user;
1144 
1145   if (ptype == MATPRODUCT_AB) {
1146     mmdata->mmAB = new MatMatStruct_AB();
1147     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1148     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1149   } else if (ptype == MATPRODUCT_AtB) {
1150     mmdata->mmAtB = new MatMatStruct_AtB();
1151     auto atb      = mmdata->mmAtB;
1152     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1153     PetscCall(ISGetIndices(glob, &garray));
1154     PetscCall(ISGetSize(glob, &sz));
1155     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
1156     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1157     PetscCall(ISRestoreIndices(glob, &garray));
1158     PetscCall(ISDestroy(&glob));
1159     mm = static_cast<MatMatStruct *>(atb);
1160   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1161     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1162     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1163     auto ab       = mmdata->mmAB;
1164     auto atb      = mmdata->mmAtB;
1165     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1166     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1167     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1168     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1169     mm = static_cast<MatMatStruct *>(atb);
1170   }
1171   /* Split the C_global into petsc A, B format */
1172   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1173   C->product->data       = mmdata;
1174   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1175   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1176   PetscFunctionReturn(0);
1177 }
1178 
1179 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1180 {
1181   Mat_Product *product = mat->product;
1182   PetscBool    match   = PETSC_FALSE;
1183   PetscBool    usecpu  = PETSC_FALSE;
1184 
1185   PetscFunctionBegin;
1186   MatCheckProduct(mat, 1);
1187   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1188   if (match) { /* we can always fallback to the CPU if requested */
1189     switch (product->type) {
1190     case MATPRODUCT_AB:
1191       if (product->api_user) {
1192         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1193         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1194         PetscOptionsEnd();
1195       } else {
1196         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1197         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1198         PetscOptionsEnd();
1199       }
1200       break;
1201     case MATPRODUCT_AtB:
1202       if (product->api_user) {
1203         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1204         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1205         PetscOptionsEnd();
1206       } else {
1207         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1208         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1209         PetscOptionsEnd();
1210       }
1211       break;
1212     case MATPRODUCT_PtAP:
1213       if (product->api_user) {
1214         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1215         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1216         PetscOptionsEnd();
1217       } else {
1218         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1219         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1220         PetscOptionsEnd();
1221       }
1222       break;
1223     default:
1224       break;
1225     }
1226     match = (PetscBool)!usecpu;
1227   }
1228   if (match) {
1229     switch (product->type) {
1230     case MATPRODUCT_AB:
1231     case MATPRODUCT_AtB:
1232     case MATPRODUCT_PtAP:
1233       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1234       break;
1235     default:
1236       break;
1237     }
1238   }
1239   /* fallback to MPIAIJ ops */
1240   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1241   PetscFunctionReturn(0);
1242 }
1243 
1244 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1245 {
1246   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1247   Mat_MPIAIJKokkos *mpikok;
1248 
1249   PetscFunctionBegin;
1250   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1251   mat->preallocated = PETSC_TRUE;
1252   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1253   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1254   PetscCall(MatZeroEntries(mat));
1255   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1256   delete mpikok;
1257   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1258   PetscFunctionReturn(0);
1259 }
1260 
1261 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1262 {
1263   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1264   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1265   Mat                         A = mpiaij->A, B = mpiaij->B;
1266   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1267   MatScalarKokkosView         Aa, Ba;
1268   MatScalarKokkosView         v1;
1269   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1270   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1271   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1272   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1273   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1274   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1275   PetscMemType                memtype;
1276 
1277   PetscFunctionBegin;
1278   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1279   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1280     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1281   } else {
1282     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1283   }
1284 
1285   if (imode == INSERT_VALUES) {
1286     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1287     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1288   } else {
1289     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1290     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1291   }
1292 
1293   /* Pack entries to be sent to remote */
1294   Kokkos::parallel_for(
1295     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1296 
1297   /* Send remote entries to their owner and overlap the communication with local computation */
1298   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1299   /* Add local entries to A and B in one kernel */
1300   Kokkos::parallel_for(
1301     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1302       PetscScalar sum = 0.0;
1303       if (i < Annz) {
1304         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1305         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1306       } else {
1307         i -= Annz;
1308         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1309         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1310       }
1311     });
1312   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1313 
1314   /* Add received remote entries to A and B in one kernel */
1315   Kokkos::parallel_for(
1316     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1317       if (i < Annz2) {
1318         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1319       } else {
1320         i -= Annz2;
1321         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1322       }
1323     });
1324 
1325   if (imode == INSERT_VALUES) {
1326     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1327     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1328   } else {
1329     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1330     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1331   }
1332   PetscFunctionReturn(0);
1333 }
1334 
1335 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1336 {
1337   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1338 
1339   PetscFunctionBegin;
1340   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1341   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1342   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1343   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1344   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1345   PetscCall(MatDestroy_MPIAIJ(A));
1346   PetscFunctionReturn(0);
1347 }
1348 
1349 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1350 {
1351   Mat         B;
1352   Mat_MPIAIJ *a;
1353 
1354   PetscFunctionBegin;
1355   if (reuse == MAT_INITIAL_MATRIX) {
1356     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1357   } else if (reuse == MAT_REUSE_MATRIX) {
1358     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1359   }
1360   B = *newmat;
1361 
1362   B->boundtocpu = PETSC_FALSE;
1363   PetscCall(PetscFree(B->defaultvectype));
1364   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1365   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1366 
1367   a = static_cast<Mat_MPIAIJ *>(A->data);
1368   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1369   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1370   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1371 
1372   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1373   B->ops->mult                  = MatMult_MPIAIJKokkos;
1374   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1375   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1376   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1377   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1378 
1379   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1380   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1381   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1382   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1383   PetscFunctionReturn(0);
1384 }
1385 /*MC
1386    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1387 
1388    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
1389 
1390    Options Database Keys:
1391 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1392 
1393   Level: beginner
1394 
1395 .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1396 M*/
1397 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1398 {
1399   PetscFunctionBegin;
1400   PetscCall(PetscKokkosInitializeCheck());
1401   PetscCall(MatCreate_MPIAIJ(A));
1402   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1403   PetscFunctionReturn(0);
1404 }
1405 
1406 /*@C
1407    MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
1408    (the default parallel PETSc format).  This matrix will ultimately pushed down
1409    to Kokkos for calculations. For good matrix
1410    assembly performance the user should preallocate the matrix storage by setting
1411    the parameter nz (or the array nnz).  By setting these parameters accurately,
1412    performance during matrix assembly can be increased by more than a factor of 50.
1413 
1414    Collective
1415 
1416    Input Parameters:
1417 +  comm - MPI communicator, set to `PETSC_COMM_SELF`
1418 .  m - number of rows
1419 .  n - number of columns
1420 .  nz - number of nonzeros per row (same for all rows)
1421 -  nnz - array containing the number of nonzeros in the various rows
1422          (possibly different for each row) or NULL
1423 
1424    Output Parameter:
1425 .  A - the matrix
1426 
1427    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1428    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1429    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1430 
1431    Notes:
1432    If nnz is given then nz is ignored
1433 
1434    The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 77
1435    storage.  That is, the stored row and column indices can begin at
1436    either one (as in Fortran) or zero.  See the users' manual for details.
1437 
1438    Specify the preallocated storage with either nz or nnz (not both).
1439    Set nz = `PETSC_DEFAULT` and nnz = NULL for PETSc to control dynamic memory
1440    allocation.  For large problems you MUST preallocate memory or you
1441    will get TERRIBLE performance, see the users' manual chapter on matrices.
1442 
1443    By default, this format uses inodes (identical nodes) when possible, to
1444    improve numerical efficiency of matrix-vector products and solves. We
1445    search for consecutive rows with the same nonzero structure, thereby
1446    reusing matrix information to achieve increased efficiency.
1447 
1448    Level: intermediate
1449 
1450 .seealso: `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
1451 @*/
1452 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1453 {
1454   PetscMPIInt size;
1455 
1456   PetscFunctionBegin;
1457   PetscCall(MatCreate(comm, A));
1458   PetscCall(MatSetSizes(*A, m, n, M, N));
1459   PetscCallMPI(MPI_Comm_size(comm, &size));
1460   if (size > 1) {
1461     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1462     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1463   } else {
1464     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1465     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1466   }
1467   PetscFunctionReturn(0);
1468 }
1469 
1470 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1471 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1472 {
1473   PetscMPIInt                size, rank;
1474   MPI_Comm                   comm;
1475   PetscSplitCSRDataStructure d_mat = NULL;
1476 
1477   PetscFunctionBegin;
1478   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1479   PetscCallMPI(MPI_Comm_size(comm, &size));
1480   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1481   if (size == 1) {
1482     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1483     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1484   } else {
1485     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1486     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1487     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1488     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1489     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1490   }
1491   // act like MatSetValues because not called on host
1492   if (A->assembled) {
1493     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1494     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1495   } else {
1496     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1497   }
1498   if (!d_mat) {
1499     struct _n_SplitCSRMat h_mat; /* host container */
1500     Mat_SeqAIJKokkos     *aijkokA;
1501     Mat_SeqAIJ           *jaca;
1502     PetscInt              n = A->rmap->n, nnz;
1503     Mat                   Amat;
1504     PetscInt             *colmap;
1505 
1506     /* create and copy h_mat */
1507     h_mat.M = A->cmap->N; // use for debug build
1508     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1509     if (size == 1) {
1510       Amat            = A;
1511       jaca            = (Mat_SeqAIJ *)A->data;
1512       h_mat.rstart    = 0;
1513       h_mat.rend      = A->rmap->n;
1514       h_mat.cstart    = 0;
1515       h_mat.cend      = A->cmap->n;
1516       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1517       h_mat.offdiag.a                   = NULL;
1518       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1519     } else {
1520       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1521       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1522       PetscInt          ii;
1523       Mat_SeqAIJKokkos *aijkokB;
1524 
1525       Amat    = aij->A;
1526       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1527       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1528       jaca    = (Mat_SeqAIJ *)aij->A->data;
1529       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1530       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1531       aij->donotstash          = PETSC_TRUE;
1532       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1533       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1534       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1535       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1536       // allocate B copy data
1537       h_mat.rstart = A->rmap->rstart;
1538       h_mat.rend   = A->rmap->rend;
1539       h_mat.cstart = A->cmap->rstart;
1540       h_mat.cend   = A->cmap->rend;
1541       nnz          = jacb->i[n];
1542       if (jacb->compressedrow.use) {
1543         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1544         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1545         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1546         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1547       } else {
1548         h_mat.offdiag.i = aijkokB->i_device_data();
1549       }
1550       h_mat.offdiag.j = aijkokB->j_device_data();
1551       h_mat.offdiag.a = aijkokB->a_device_data();
1552       {
1553         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1554         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1555         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1556         h_mat.colmap = aijkokB->colmap_d.data();
1557         PetscCall(PetscFree(colmap));
1558       }
1559       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1560       h_mat.offdiag.n                 = n;
1561     }
1562     // allocate A copy data
1563     nnz                          = jaca->i[n];
1564     h_mat.diag.n                 = n;
1565     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1566     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1567     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
1568     h_mat.diag.i = aijkokA->i_device_data();
1569     h_mat.diag.j = aijkokA->j_device_data();
1570     h_mat.diag.a = aijkokA->a_device_data();
1571     // copy pointers and metdata to device
1572     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1573     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1574     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1575   }
1576   *B           = d_mat;       // return it, set it in Mat, and set it up
1577   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1578   PetscFunctionReturn(0);
1579 }
1580 
1581 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1582 {
1583   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1584 
1585   PetscFunctionBegin;
1586   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1587   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1588   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1589   else *mask = "PETSC_OFFLOAD_BOTH";
1590   PetscFunctionReturn(0);
1591 }
1592 
1593 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1594 {
1595   PetscMPIInt size;
1596   Mat         Ad, Ao;
1597   const char *amask, *bmask;
1598 
1599   PetscFunctionBegin;
1600   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1601 
1602   if (size == 1) {
1603     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1604     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1605   } else {
1606     Ad = ((Mat_MPIAIJ *)A->data)->A;
1607     Ao = ((Mat_MPIAIJ *)A->data)->B;
1608     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1609     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1610     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1611   }
1612   PetscFunctionReturn(0);
1613 }
1614