xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision ed00a6c129cbc4f1b4754b50ab0c4e7536acfadd)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petscsf.h>
4 #include <petsc/private/sfimpl.h>
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7 #include <KokkosSparse_spadd.hpp>
8 #include <KokkosSparse_spgemm.hpp>
9 
10 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11 {
12   Mat_SeqAIJKokkos *aijkok;
13   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19    */
20   if (mode == MAT_FINAL_ASSEMBLY) {
21     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
22     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
23     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
24   }
25   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
26   if (aijkok && aijkok->device_mat_d.data()) {
27     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
28   }
29 
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
33 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
34 {
35   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
36 
37   PetscFunctionBegin;
38   PetscCall(PetscLayoutSetUp(mat->rmap));
39   PetscCall(PetscLayoutSetUp(mat->cmap));
40 #if defined(PETSC_USE_DEBUG)
41   if (d_nnz) {
42     PetscInt i;
43     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
44   }
45   if (o_nnz) {
46     PetscInt i;
47     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
48   }
49 #endif
50 #if defined(PETSC_USE_CTABLE)
51   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
52 #else
53   PetscCall(PetscFree(mpiaij->colmap));
54 #endif
55   PetscCall(PetscFree(mpiaij->garray));
56   PetscCall(VecDestroy(&mpiaij->lvec));
57   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
58   /* Because the B will have been resized we simply destroy it and create a new one each time */
59   PetscCall(MatDestroy(&mpiaij->B));
60 
61   if (!mpiaij->A) {
62     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
63     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
64   }
65   if (!mpiaij->B) {
66     PetscMPIInt size;
67     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
68     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
69     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
70   }
71   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
72   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
73   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
74   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
75   mat->preallocated = PETSC_TRUE;
76   PetscFunctionReturn(PETSC_SUCCESS);
77 }
78 
79 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
80 {
81   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
82   PetscInt    nt;
83 
84   PetscFunctionBegin;
85   PetscCall(VecGetLocalSize(xx, &nt));
86   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
87   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
88   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
89   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
90   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
91   PetscFunctionReturn(PETSC_SUCCESS);
92 }
93 
94 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
95 {
96   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
97   PetscInt    nt;
98 
99   PetscFunctionBegin;
100   PetscCall(VecGetLocalSize(xx, &nt));
101   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
102   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
103   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
104   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
105   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
106   PetscFunctionReturn(PETSC_SUCCESS);
107 }
108 
109 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
110 {
111   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
112   PetscInt    nt;
113 
114   PetscFunctionBegin;
115   PetscCall(VecGetLocalSize(xx, &nt));
116   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
117   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
118   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
119   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
120   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
121   PetscFunctionReturn(PETSC_SUCCESS);
122 }
123 
124 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
125    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
126    C still uses local column ids. Their corresponding global column ids are returned in glob.
127 */
128 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
129 {
130   Mat             Ad, Ao;
131   const PetscInt *cmap;
132 
133   PetscFunctionBegin;
134   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
135   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
136   if (glob) {
137     PetscInt cst, i, dn, on, *gidx;
138     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
139     PetscCall(MatGetLocalSize(Ao, NULL, &on));
140     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
141     PetscCall(PetscMalloc1(dn + on, &gidx));
142     for (i = 0; i < dn; i++) gidx[i] = cst + i;
143     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
144     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
145   }
146   PetscFunctionReturn(PETSC_SUCCESS);
147 }
148 
149 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
150 struct MatMatStruct {
151   PetscInt            n, *garray;     // C's garray and its size.
152   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
153   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
154   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
155   PetscIntKokkosView  E_NzLeft;
156   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
157   MatScalarKokkosView rootBuf, leafBuf;
158   KokkosCsrMatrix     Fd, Fo; // F in split form
159 
160   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
161   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
162   KernelHandle kh3; // compute C3
163   KernelHandle kh4; // compute C4
164 
165   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
166   PetscInt E_VectorLength;
167   PetscInt E_RowsPerTeam;
168   PetscInt F_TeamSize;
169   PetscInt F_VectorLength;
170   PetscInt F_RowsPerTeam;
171 
172   ~MatMatStruct()
173   {
174     PetscFunctionBegin;
175     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
176     PetscFunctionReturnVoid();
177   }
178 };
179 
180 struct MatMatStruct_AB : public MatMatStruct {
181   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
182   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
183   PetscIntKokkosView rowoffset;
184 };
185 
186 struct MatMatStruct_AtB : public MatMatStruct {
187   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
188   MatColIdxKokkosView Fdjperm;
189   MatColIdxKokkosView Fojmap;
190   MatColIdxKokkosView Fojperm;
191 };
192 
193 struct MatProductData_MPIAIJKokkos {
194   MatMatStruct_AB  *mmAB     = nullptr;
195   MatMatStruct_AtB *mmAtB    = nullptr;
196   PetscBool         reusesym = PETSC_FALSE;
197   Mat               Z        = nullptr; // store Z=AB in computing BtAB
198 
199   ~MatProductData_MPIAIJKokkos()
200   {
201     delete mmAB;
202     delete mmAtB;
203     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
204   }
205 };
206 
207 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
208 {
209   PetscFunctionBegin;
210   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
211   PetscFunctionReturn(PETSC_SUCCESS);
212 }
213 
214 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
215    It is similar to MatCreateMPIAIJWithSplitArrays.
216 
217   Input Parameters:
218 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
219 .  A     - the diag matrix using local col ids
220 -  B     - the offdiag matrix using global col ids
221 
222   Output Parameter:
223 .  mat   - the updated MATMPIAIJKOKKOS matrix
224 */
225 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
226 {
227   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
228   PetscInt    m, n, M, N, Am, An, Bm, Bn;
229 
230   PetscFunctionBegin;
231   PetscCall(MatGetSize(mat, &M, &N));
232   PetscCall(MatGetLocalSize(mat, &m, &n));
233   PetscCall(MatGetLocalSize(A, &Am, &An));
234   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
235 
236   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
237   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
238   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
239   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
240   mpiaij->A      = A;
241   mpiaij->B      = B;
242   mpiaij->garray = garray;
243 
244   mat->preallocated     = PETSC_TRUE;
245   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
246 
247   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
248   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
249   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
250     also gets mpiaij->B compacted, with its col ids and size reduced
251   */
252   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
253   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
254   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
255   PetscFunctionReturn(PETSC_SUCCESS);
256 }
257 
258 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
259 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
260 template <class ExecutionSpace>
261 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
262 {
263   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
264 
265   PetscFunctionBegin;
266   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
267 
268   if (nnz_per_row < 1) nnz_per_row = 1;
269 
270   int max_vector_length = teamPolicy.vector_length_max();
271 
272   if (vector_length < 1) {
273     vector_length = 1;
274     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
275   }
276 
277   // Determine rows per thread
278   if (rows_per_thread < 1) {
279     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
280     else {
281       if (nnz_per_row < 20 && nnz > 5000000) {
282         rows_per_thread = 256;
283       } else rows_per_thread = 64;
284     }
285   }
286 
287   if (team_size < 1) {
288     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
289       team_size = 256 / vector_length;
290     } else {
291       team_size = 1;
292     }
293   }
294 
295   rows_per_team = rows_per_thread * team_size;
296 
297   if (rows_per_team < 0) {
298     PetscInt nnz_per_team = 4096;
299     PetscInt conc         = ExecutionSpace().concurrency();
300     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
301     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
302   }
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 /*
307   Reduce two sets of global indices into local ones
308 
309   Input Parameters:
310 +  n1          - size of garray1[], the first set
311 .  garray1[n1] - a sorted global index array (without duplicates)
312 .  m           - size of indices[], the second set
313 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
314 
315   Output Parameters:
316 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
317 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
318 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
319 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
320 
321    Example, say
322     n1         = 5
323     garray1[5] = {1, 4, 7, 8, 10}
324     m          = 4
325     indices[4] = {2, 4, 8, 9}
326 
327    Combining them together, we have 7 global indices in garray2[]
328     n2         = 7
329     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
330 
331    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
332     map[5] = {0, 2, 3, 4, 6}
333 
334    On output, indices[] is updated with local indices
335     indices[4] = {1, 2, 4, 5}
336 */
337 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
338 {
339   PetscHMapI    g2l = nullptr;
340   PetscHashIter iter;
341   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
342   PetscInt      n2, *garray2;
343 
344   PetscFunctionBegin;
345   tot = 0;
346   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
347   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
348     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
349     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
350   }
351 
352   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
353     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
354     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
355   }
356 
357   // Pull out (unique) globals in the hash table and put them in garray2[]
358   n2 = tot;
359   PetscCall(PetscMalloc1(n2, &garray2));
360   tot = 0;
361   PetscHashIterBegin(g2l, iter);
362   while (!PetscHashIterAtEnd(g2l, iter)) {
363     PetscHashIterGetKey(g2l, iter, key);
364     PetscHashIterNext(g2l, iter);
365     garray2[tot++] = key;
366   }
367 
368   // Sort garray2[] and then map them to local indices starting from 0
369   PetscCall(PetscSortInt(n2, garray2));
370   PetscCall(PetscHMapIClear(g2l));
371   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
372 
373   // Rewrite indices[] with local indices
374   for (PetscInt i = 0; i < m; i++) {
375     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
376     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
377     indices[i] = val;
378   }
379   // Record the map that maps garray1[i] to garray2[map[i]]
380   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
381   PetscCall(PetscHMapIDestroy(&g2l));
382   *n2_      = n2;
383   *garray2_ = garray2;
384   PetscFunctionReturn(PETSC_SUCCESS);
385 }
386 
387 /*
388   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
389 
390   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
391 
392   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
393   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
394 
395   Input Parameters:
396 +  comm       - MPI communicator of E
397 .  A          - diag block of E, using local column indices
398 .  B          - off-diag block of E, using local column indices
399 .  cstart      - (global) start column of Ed
400 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
401 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
402 .  ownerSF     - the SF specifies ownership (root) of rows in E
403 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
404 -  mm          - to stash intermediate data structures for reuse
405 
406   Output Parameters:
407 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
408 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
409 
410   Notes:
411   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
412 
413  */
414 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
415 {
416   PetscFunctionBegin;
417   if (reuse == MAT_INITIAL_MATRIX) {
418     PetscInt Em = A.numRows(), Fm;
419     PetscInt n1 = B.numCols();
420 
421     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
422 
423     // Do the analysis on host
424     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
425     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
426     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
427     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
428     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
429     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
430 
431     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
432     PetscIntKokkosViewHost E_NzLeft_h("E_NzLeft_h", Em), E_RowLen_h("E_RowLen_h", Em);
433     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
434     for (PetscInt i = 0; i < Em; i++) {
435       const PetscInt *first, *last, *it;
436       PetscInt        count, step;
437       // std::lower_bound(first,last,cstart), but need to use global column indices
438       first = Bj + Bi[i];
439       last  = Bj + Bi[i + 1];
440       count = last - first;
441       while (count > 0) {
442         it   = first;
443         step = count / 2;
444         it += step;
445         if (garray1[*it] < cstart) { // map local to global
446           first = ++it;
447           count -= step + 1;
448         } else count = step;
449       }
450       E_NzLeft[i] = first - (Bj + Bi[i]);
451       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
452     }
453 
454     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
455     const PetscMPIInt *iranks, *ranks;
456     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
457     PetscInt           niranks, nranks;
458     MPI_Request       *reqs;
459     PetscMPIInt        tag;
460     PetscSF            reduceSF;
461     PetscInt          *sdisp, *rdisp;
462 
463     PetscCall(PetscCommGetNewTag(comm, &tag));
464     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
465     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
466 
467     // Find out length of each row I will receive. Even for the same row index, when they are from
468     // different senders, they might have different lengths (and sparsity patterns)
469     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
470     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
471 
472     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
473 
474     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
475     recvRowLen[0] = 0; // since we will make it in CSR format later
476     recvRowLen++;      // advance the pointer now
477     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
478     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
479     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
480 
481     // Build the real PetscSF for reducing E rows (buffer to buffer)
482     rdisp[0] = 0;
483     for (PetscInt i = 0; i < niranks; i++) {
484       rdisp[i + 1] = rdisp[i];
485       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
486     }
487     recvRowLen--; // put it back into csr format
488     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
489 
490     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
491     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
492     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
493 
494     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
495     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
496     PetscSFNode *iremote;
497 
498     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
499     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
500     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
501 
502     for (PetscInt i = 0; i < nranks; i++) {
503       PetscInt count = 0;
504       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
505       for (PetscInt j = 0; j < count; j++) {
506         iremote[nleaves + j].rank  = ranks[i];
507         iremote[nleaves + j].index = sdisp[i] + j;
508       }
509       nleaves += count;
510     }
511     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
512 
513     PetscCall(PetscSFCreate(comm, &reduceSF));
514     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
515 
516     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
517     PetscInt *sendCol, *recvCol;
518     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
519     for (PetscInt k = 0; k < roffset[nranks]; k++) {
520       PetscInt  i      = rmine[k]; // row to be copied
521       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
522       PetscInt  nzLeft = E_NzLeft[i];
523       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
524       for (PetscInt j = 0; j < alen + blen; j++) {
525         if (j < nzLeft) {
526           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
527         } else if (j < nzLeft + alen) {
528           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
529         } else {
530           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
531         }
532       }
533     }
534     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
535     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
536 
537     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
538     PetscInt *recvRowPerm, *recvColSorted;
539     PetscInt *recvNzPerm, *recvNzPermSorted;
540     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
541 
542     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
543     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
544     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
545 
546     // i[] array, nz are always easiest to compute
547     MatRowMapKokkosViewHost Fdi_h("Fdi_h", Fm + 1), Foi_h("Foi_h", Fm + 1);
548     MatRowMapType          *Fdi, *Foi;
549     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
550     PetscInt                iter;
551 
552     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
553     Kokkos::deep_copy(Foi_h, 0);
554     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
555     Foi  = Foi_h.data() + 1;
556     iter = 0;
557     while (iter < recvRowCnt) { // iter over received rows
558       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
559       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
560 
561       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
562 
563       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
564       PetscInt  nz    = 0; // nz (with dups) in the current row
565       PetscInt *jbuf  = recvColSorted + FnzDups;
566       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
567       PetscInt *jbuf2 = jbuf; // temp pointers
568       PetscInt *pbuf2 = pbuf;
569       for (PetscInt d = 0; d < dupRows; d++) {
570         PetscInt i   = recvRowPerm[iter + d];
571         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
572         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
573         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
574         jbuf2 += len;
575         pbuf2 += len;
576         nz += len;
577       }
578       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
579 
580       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
581       PetscInt cur = 0;
582       while (cur < nz) {
583         PetscInt curColIdx = jbuf[cur];
584         PetscInt dups      = 1;
585 
586         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
587         if (curColIdx >= cstart && curColIdx < cend) {
588           Fdi[curRowIdx]++;
589           FdnzDups += dups;
590         } else {
591           Foi[curRowIdx]++;
592           FonzDups += dups;
593         }
594         cur += dups;
595       }
596 
597       FnzDups += nz;
598       iter += dupRows; // Move to next unique row
599     }
600 
601     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
602     Foi = Foi_h.data();
603     for (PetscInt i = 0; i < Fm; i++) {
604       Fdi[i + 1] += Fdi[i];
605       Foi[i + 1] += Foi[i];
606     }
607     Fdnz = Fdi[Fm];
608     Fonz = Foi[Fm];
609     PetscCall(PetscFree2(sendCol, recvCol));
610 
611     // Allocate j, jmap, jperm for Fd and Fo
612     MatColIdxKokkosViewHost Fdj_h("Fdj_h", Fdnz), Foj_h("Foj_h", Fonz);
613     MatRowMapKokkosViewHost Fdjmap_h("Fdjmap_h", Fdnz + 1), Fojmap_h("Fojmap_h", Fonz + 1); // +1 to make csr
614     MatRowMapKokkosViewHost Fdjperm_h("Fdjperm_h", FdnzDups), Fojperm_h("Fojperm_h", FonzDups);
615     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
616     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
617     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
618 
619     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
620     Fdjmap[0] = 0;
621     Fojmap[0] = 0;
622     FnzDups   = 0;
623     Fdnz      = 0;
624     Fonz      = 0;
625     iter      = 0; // iter over received rows
626     while (iter < recvRowCnt) {
627       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
628       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
629       PetscInt nz        = 0;                           // nz (with dups) in the current row
630 
631       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
632       for (PetscInt d = 0; d < dupRows; d++) {
633         PetscInt i = recvRowPerm[iter + d];
634         nz += recvRowLen[i + 1] - recvRowLen[i];
635       }
636 
637       PetscInt *jbuf = recvColSorted + FnzDups;
638       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
639       PetscInt cur = 0;
640       while (cur < nz) {
641         PetscInt curColIdx = jbuf[cur];
642         PetscInt dups      = 1;
643 
644         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
645         if (curColIdx >= cstart && curColIdx < cend) {
646           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
647           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
648           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
649           FdnzDups += dups;
650           Fdnz++;
651         } else {
652           Foj[Fonz]        = curColIdx; // in global
653           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
654           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
655           FonzDups += dups;
656           Fonz++;
657         }
658         cur += dups;
659         FnzDups += dups;
660       }
661       iter += dupRows; // Move to next unique row
662     }
663     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
664     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
665 
666     // Combine global column indices in garray1[] and Foj[]
667     PetscInt n2, *garray2;
668 
669     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
670     mm->sf       = reduceSF;
671     mm->leafBuf  = MatScalarKokkosView("leafBuf", nleaves);
672     mm->rootBuf  = MatScalarKokkosView("rootBuf", nroots);
673     mm->garray   = garray2; // give ownership, so no free
674     mm->n        = n2;
675     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
676     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
677     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
678     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
679     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
680 
681     // Output Fd and Fo in KokkosCsrMatrix format
682     MatScalarKokkosView Fda_d("Fda_d", Fdnz);
683     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
684     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
685     MatScalarKokkosView Foa_d("Foa_d", Fonz);
686     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
687     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
688 
689     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
690     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
691 
692     // Compute kernel launch parameters in merging E
693     PetscInt teamSize, vectorLength, rowsPerTeam;
694 
695     teamSize = vectorLength = rowsPerTeam = -1;
696     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
697     mm->E_TeamSize     = teamSize;
698     mm->E_VectorLength = vectorLength;
699     mm->E_RowsPerTeam  = rowsPerTeam;
700   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
701 
702   // Handy aliases
703   auto       &Aa           = A.values;
704   auto       &Ba           = B.values;
705   const auto &Ai           = A.graph.row_map;
706   const auto &Bi           = B.graph.row_map;
707   const auto &E_NzLeft     = mm->E_NzLeft;
708   auto       &leafBuf      = mm->leafBuf;
709   auto       &rootBuf      = mm->rootBuf;
710   PetscSF     reduceSF     = mm->sf;
711   PetscInt    Em           = A.numRows();
712   PetscInt    teamSize     = mm->E_TeamSize;
713   PetscInt    vectorLength = mm->E_VectorLength;
714   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
715   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
716 
717   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
718   PetscCallCXX(Kokkos::parallel_for(
719     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
720       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
721         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
722         if (i < Em) {
723           PetscInt disp   = Ai(i) + Bi(i);
724           PetscInt alen   = Ai(i + 1) - Ai(i);
725           PetscInt blen   = Bi(i + 1) - Bi(i);
726           PetscInt nzleft = E_NzLeft(i);
727 
728           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
729             MatScalar &val = leafBuf(disp + j);
730             if (j < nzleft) { // B left
731               val = Ba(Bi(i) + j);
732             } else if (j < nzleft + alen) { // diag A
733               val = Aa(Ai(i) + j - nzleft);
734             } else { // B right
735               val = Ba(Bi(i) + j - alen);
736             }
737           });
738         }
739       });
740     }));
741   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
742   PetscFunctionReturn(PETSC_SUCCESS);
743 }
744 
745 // To finish MatMPIAIJKokkosReduce.
746 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
747 {
748   PetscFunctionBegin;
749   auto       &leafBuf  = mm->leafBuf;
750   auto       &rootBuf  = mm->rootBuf;
751   auto       &Fda      = mm->Fd.values;
752   const auto &Fdjmap   = mm->Fdjmap;
753   const auto &Fdjperm  = mm->Fdjperm;
754   auto        Fdnz     = mm->Fd.nnz();
755   auto       &Foa      = mm->Fo.values;
756   const auto &Fojmap   = mm->Fojmap;
757   const auto &Fojperm  = mm->Fojperm;
758   auto        Fonz     = mm->Fo.nnz();
759   PetscSF     reduceSF = mm->sf;
760 
761   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
762 
763   // Reduce data in rootBuf to Fd and Fo
764   PetscCallCXX(Kokkos::parallel_for(
765     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
766       PetscScalar sum = 0.0;
767       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
768       Fda(i) = sum;
769     }));
770 
771   PetscCallCXX(Kokkos::parallel_for(
772     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
773       PetscScalar sum = 0.0;
774       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
775       Foa(i) = sum;
776     }));
777   PetscFunctionReturn(PETSC_SUCCESS);
778 }
779 
780 /*
781   MatMPIAIJKokkosBcast - Bcast local rows of an MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
782 
783   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
784   device and involves various index mapping.
785 
786   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
787   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
788   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
789   F has the same column layout as E.
790 
791   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
792   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
793   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
794   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
795   column indices in Fo and update Fo with local indices.
796 
797    Input Parameters:
798 +   E       - the MPIAIJKOKKOS matrix
799 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
800 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
801 -   mm      - to stash matproduct intermediate data structures
802 
803     Output Parameters:
804 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
805 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
806 
807     Notes:
808     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
809     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide opportunities for overlapping computation and communication.
810 */
811 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
812 {
813   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
814   Mat               A = empi->A, B = empi->B; // diag and off-diag
815   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
816   PetscInt          Em = E->rmap->n; // #local rows
817   MPI_Comm          comm;
818 
819   PetscFunctionBegin;
820   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
821   if (reuse == MAT_INITIAL_MATRIX) {
822     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
823     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
824     const PetscInt *garray1 = empi->garray; // its size is n1
825     PetscInt        cstart, cend;
826     PetscSF         bcastSF;
827 
828     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
829 
830     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
831     PetscIntKokkosViewHost E_NzLeft_h("E_NzLeft_h", Em), E_RowLen_h("E_RowLen_h", Em);
832     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
833     for (PetscInt i = 0; i < Em; i++) {
834       const PetscInt *first, *last, *it;
835       PetscInt        count, step;
836       // std::lower_bound(first,last,cstart), but need to use global column indices
837       first = Bj + Bi[i];
838       last  = Bj + Bi[i + 1];
839       count = last - first;
840       while (count > 0) {
841         it   = first;
842         step = count / 2;
843         it += step;
844         if (empi->garray[*it] < cstart) { // map local to global
845           first = ++it;
846           count -= step + 1;
847         } else count = step;
848       }
849       E_NzLeft[i] = first - (Bj + Bi[i]);
850       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
851     }
852 
853     // Compute row pointer Fi of F
854     PetscInt *Fi, Fm, Fnz;
855     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
856     PetscCall(PetscMalloc1(Fm + 1, &Fi));
857     Fi[0] = 0;
858     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
859     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
860     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
861     Fnz = Fi[Fm];
862 
863     // Build the real PetscSF for bcasting E rows (buffer to buffer)
864     const PetscMPIInt *iranks, *ranks;
865     const PetscInt    *ioffset, *irootloc, *roffset;
866     PetscInt           niranks, nranks, *sdisp, *rdisp;
867     MPI_Request       *reqs;
868     PetscMPIInt        tag;
869 
870     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
871     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
872     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
873 
874     sdisp[0] = 0; // send displacement
875     for (PetscInt i = 0; i < niranks; i++) {
876       sdisp[i + 1] = sdisp[i];
877       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
878         PetscInt r = irootloc[j]; // row to be sent
879         sdisp[i + 1] += E_RowLen[r];
880       }
881     }
882 
883     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
884     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
885     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
886     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
887 
888     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
889     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
890     PetscSFNode *iremote;                  // give ownership to bcastSF
891     PetscCall(PetscMalloc1(nleaves, &iremote));
892     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
893       PetscInt k = 0;
894       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
895         iremote[j].rank  = ranks[i];
896         iremote[j].index = rdisp[i] + k; // their root location
897         k++;
898       }
899     }
900     PetscCall(PetscSFCreate(comm, &bcastSF));
901     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
902     PetscCall(PetscFree3(sdisp, rdisp, reqs));
903 
904     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
905     PetscIntKokkosViewHost rowoffset_h("rowoffset_h", ioffset[niranks] + 1);
906     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
907     rowoffset[0]                     = 0;
908     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] += rowoffset[i] + E_RowLen[irootloc[i]]; }
909 
910     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
911     PetscInt *jbuf, *Fj;
912     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
913     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
914       PetscInt  i      = irootloc[k]; // row to be copied
915       PetscInt *buf    = &jbuf[rowoffset[k]];
916       PetscInt  nzLeft = E_NzLeft[i];
917       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
918       for (PetscInt j = 0; j < alen + blen; j++) {
919         if (j < nzLeft) {
920           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
921         } else if (j < nzLeft + alen) {
922           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
923         } else {
924           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
925         }
926       }
927     }
928     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
929     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
930 
931     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
932     MatRowMapKokkosViewHost Fdi_h("Fdi_h", Fm + 1), Foi_h("Foi_h", Fm + 1); // row pointer of Fd, Fo
933     MatColIdxKokkosViewHost F_NzLeft_h("F_NzLeft_h", Fm);                   // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
934     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
935     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
936 
937     Fdi[0] = Foi[0] = 0;
938     for (PetscInt i = 0; i < Fm; i++) {
939       PetscInt *first, *last, *lb1, *lb2;
940       // cut the row into: Left, [cstart, cend), Right
941       first       = Fj + Fi[i];
942       last        = Fj + Fi[i + 1];
943       lb1         = std::lower_bound(first, last, cstart);
944       F_NzLeft[i] = lb1 - first;
945       lb2         = std::lower_bound(first, last, cend);
946       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
947       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
948     }
949     for (PetscInt i = 0; i < Fm; i++) {
950       Fdi[i + 1] += Fdi[i];
951       Foi[i + 1] += Foi[i];
952     }
953 
954     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
955     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
956     MatColIdxKokkosViewHost Fdj_h("Fdj_h", Fdnz), Foj_h("Foj_h", Fonz);
957     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
958 
959     for (PetscInt i = 0; i < Fm; i++) {
960       PetscInt nzLeft = F_NzLeft[i];
961       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
962       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
963         gid = Fj[Fi[i] + j];
964         if (j < nzLeft) { // left, in global
965           Foj[Foi[i] + j] = gid;
966         } else if (j < nzLeft + len) { // diag, in local
967           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
968         } else { // right, in global
969           Foj[Foi[i] + j - len] = gid;
970         }
971       }
972     }
973     PetscCall(PetscFree2(jbuf, Fj));
974     PetscCall(PetscFree(Fi));
975 
976     // Reduce global indices in Foj[] and garray1[] into local ones
977     PetscInt n2, *garray2;
978     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
979 
980     // Record the plans built above, for reuse
981     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
982     PetscIntKokkosViewHost irootloc_h("irootloc_h", ioffset[niranks]);
983     Kokkos::deep_copy(irootloc_h, tmp);
984     mm->sf        = bcastSF;
985     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
986     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
987     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
988     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
989     mm->rootBuf   = MatScalarKokkosView("rootBuf", nroots);
990     mm->leafBuf   = MatScalarKokkosView("leafBuf", nleaves);
991     mm->garray    = garray2;
992     mm->n         = n2;
993 
994     // Output Fd and Fo in KokkosCsrMatrix format
995     MatScalarKokkosView Fda_d("Fda_d", Fdnz), Foa_d("Foa_d", Fonz);
996     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
997     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
998     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
999     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
1000 
1001     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1002     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1003 
1004     // Compute kernel launch parameters in merging E or splitting F
1005     PetscInt teamSize, vectorLength, rowsPerTeam;
1006 
1007     teamSize = vectorLength = rowsPerTeam = -1;
1008     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1009     mm->E_TeamSize     = teamSize;
1010     mm->E_VectorLength = vectorLength;
1011     mm->E_RowsPerTeam  = rowsPerTeam;
1012 
1013     teamSize = vectorLength = rowsPerTeam = -1;
1014     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1015     mm->F_TeamSize     = teamSize;
1016     mm->F_VectorLength = vectorLength;
1017     mm->F_RowsPerTeam  = rowsPerTeam;
1018   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1019 
1020   // Sync E's value to device
1021   akok->a_dual.sync_device();
1022   bkok->a_dual.sync_device();
1023 
1024   // Handy aliases
1025   const auto &Aa = akok->a_dual.view_device();
1026   const auto &Ba = bkok->a_dual.view_device();
1027   const auto &Ai = akok->i_dual.view_device();
1028   const auto &Bi = bkok->i_dual.view_device();
1029 
1030   // Fetch the plans
1031   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1032   PetscSF             &bcastSF   = mm->sf;
1033   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1034   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1035   PetscIntKokkosView  &irootloc  = mm->irootloc;
1036   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1037 
1038   PetscInt teamSize     = mm->E_TeamSize;
1039   PetscInt vectorLength = mm->E_VectorLength;
1040   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1041   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1042 
1043   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1044   PetscCallCXX(Kokkos::parallel_for(
1045     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1046       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1047         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1048         if (r < irootloc.extent(0)) {
1049           PetscInt i      = irootloc(r); // row i of E
1050           PetscInt disp   = rowoffset(r);
1051           PetscInt alen   = Ai(i + 1) - Ai(i);
1052           PetscInt blen   = Bi(i + 1) - Bi(i);
1053           PetscInt nzleft = E_NzLeft(i);
1054 
1055           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1056             if (j < nzleft) { // B left
1057               rootBuf(disp + j) = Ba(Bi(i) + j);
1058             } else if (j < nzleft + alen) { // diag A
1059               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1060             } else { // B right
1061               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1062             }
1063           });
1064         }
1065       });
1066     }));
1067   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1068   PetscFunctionReturn(PETSC_SUCCESS);
1069 }
1070 
1071 // To finish MatMPIAIJKokkosBcast.
1072 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1073 {
1074   PetscFunctionBegin;
1075   const auto &Fd  = mm->Fd;
1076   const auto &Fo  = mm->Fo;
1077   const auto &Fdi = Fd.graph.row_map;
1078   const auto &Foi = Fo.graph.row_map;
1079   auto       &Fda = Fd.values;
1080   auto       &Foa = Fo.values;
1081   auto        Fm  = Fd.numRows();
1082 
1083   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1084   PetscSF             &bcastSF      = mm->sf;
1085   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1086   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1087   PetscInt             teamSize     = mm->F_TeamSize;
1088   PetscInt             vectorLength = mm->F_VectorLength;
1089   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1090   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1091 
1092   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1093 
1094   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1095   PetscCallCXX(Kokkos::parallel_for(
1096     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1097       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1098         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1099         if (i < Fm) {
1100           PetscInt nzLeft = F_NzLeft(i);
1101           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1102           PetscInt blen   = Foi(i + 1) - Foi(i);
1103           PetscInt Fii    = Fdi(i) + Foi(i);
1104 
1105           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1106             PetscScalar val = leafBuf(Fii + j);
1107             if (j < nzLeft) { // left
1108               Foa(Foi(i) + j) = val;
1109             } else if (j < nzLeft + alen) { // diag
1110               Fda(Fdi(i) + j - nzLeft) = val;
1111             } else { // right
1112               Foa(Foi(i) + j - alen) = val;
1113             }
1114           });
1115         }
1116       });
1117     }));
1118   PetscFunctionReturn(PETSC_SUCCESS);
1119 }
1120 
1121 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1122 {
1123   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1124   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1125   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1126   PetscInt        cstart, cend;
1127   MPI_Comm        comm;
1128 
1129   PetscFunctionBegin;
1130   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1131   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1132   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1133   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1134   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1135   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1136   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1137 
1138   // TODO: add command line options to select spgemm algorithms
1139   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1140 
1141   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1142 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1143   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1144   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1145   #endif
1146 #endif
1147 
1148   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1149   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1150   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1151   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1152 
1153   // Aot * (B's diag + B's off-diag)
1154   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1155   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1156   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1157   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1158   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1159   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1160 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1161   PetscCallCXX(sort_crs_matrix(mm->C3));
1162   PetscCallCXX(sort_crs_matrix(mm->C4));
1163 #endif
1164 
1165   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1166   PetscIntKokkosViewHost map_h("map_h", bmpi->B->cmap->n);
1167   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1168   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1169 
1170   // Adt * (B's diag + B's off-diag)
1171   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1172   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1173   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1174   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1175 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1176   PetscCallCXX(sort_crs_matrix(mm->C1));
1177   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1178 #endif
1179 
1180   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1181 
1182   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1183   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj("j", oldj.extent(0));
1184   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1185   PetscCallCXX(Kokkos::parallel_for(
1186     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1187   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1188 
1189   // C = (C1+Fd, C2+Fo)
1190   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1191   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1192   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1193   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1194   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1195   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1196   PetscFunctionReturn(PETSC_SUCCESS);
1197 }
1198 
1199 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1200 {
1201   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1202   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1203   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1204   MPI_Comm        comm;
1205 
1206   PetscFunctionBegin;
1207   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1208   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1209   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1210   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1211   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1212 
1213   // Aot * (B's diag + B's off-diag)
1214   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1215   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1216 
1217   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1218   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1219 
1220   // Adt * (B's diag + B's off-diag)
1221   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1222   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1223 
1224   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1225 
1226   // C = (C1+Fd, C2+Fo)
1227   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1228   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1229   PetscFunctionReturn(PETSC_SUCCESS);
1230 }
1231 
1232 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1233 
1234   Input Parameters:
1235 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1236 .  A        - an MPIAIJKOKKOS matrix
1237 .  B        - an MPIAIJKOKKOS matrix
1238 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1239 */
1240 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241 {
1242   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1243   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1244   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245 
1246   PetscFunctionBegin;
1247   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1248   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1249   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1250   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251 
1252   // TODO: add command line options to select spgemm algorithms
1253   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1254 
1255   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1256 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1257   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1258   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1259   #endif
1260 #endif
1261 
1262   mm->kh1.create_spgemm_handle(spgemm_alg);
1263   mm->kh2.create_spgemm_handle(spgemm_alg);
1264   mm->kh3.create_spgemm_handle(spgemm_alg);
1265   mm->kh4.create_spgemm_handle(spgemm_alg);
1266 
1267   // Bcast B's rows to form F, and overlap the communication
1268   PetscIntKokkosViewHost map_h("map_h", bmpi->B->cmap->n);
1269   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1270 
1271   // A's diag * (B's diag + B's off-diag)
1272   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1273   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1274   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1275   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1276   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1277   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1278 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1279   PetscCallCXX(sort_crs_matrix(mm->C1));
1280   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1281 #endif
1282 
1283   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1284 
1285   // A's off-diag * (F's diag + F's off-diag)
1286   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1287   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1288   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1289   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1290 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1291   PetscCallCXX(sort_crs_matrix(mm->C3));
1292   PetscCallCXX(sort_crs_matrix(mm->C4));
1293 #endif
1294 
1295   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1296   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj("j", oldj.extent(0));
1297   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1298   PetscCallCXX(Kokkos::parallel_for(
1299     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1300   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1301 
1302   // C = (Cd, Co) = (C1+C3, C2+C4)
1303   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1304   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1305   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1306   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1307   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1308   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1309   PetscFunctionReturn(PETSC_SUCCESS);
1310 }
1311 
1312 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1313 {
1314   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1315   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1316   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1317 
1318   PetscFunctionBegin;
1319   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1320   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1321   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1322   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1323 
1324   // Bcast B's rows to form F, and overlap the communication
1325   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1326 
1327   // A's diag * (B's diag + B's off-diag)
1328   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1329   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1330 
1331   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1332 
1333   // A's off-diag * (F's diag + F's off-diag)
1334   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1335   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1336 
1337   // C = (Cd, Co) = (C1+C3, C2+C4)
1338   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1339   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1340   PetscFunctionReturn(PETSC_SUCCESS);
1341 }
1342 
1343 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1344 {
1345   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1346   Mat_Product                 *product;
1347   MatProductData_MPIAIJKokkos *pdata;
1348   MatProductType               ptype;
1349   Mat                          A, B;
1350 
1351   PetscFunctionBegin;
1352   MatCheckProduct(C, 1); // make sure C is a product
1353   product = C->product;
1354   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1355   ptype   = product->type;
1356   A       = product->A;
1357   B       = product->B;
1358 
1359   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1360   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1361   // we still do numeric.
1362   if (pdata->reusesym) { // numeric reuses results from symbolic
1363     pdata->reusesym = PETSC_FALSE;
1364     PetscFunctionReturn(PETSC_SUCCESS);
1365   }
1366 
1367   if (ptype == MATPRODUCT_AB) {
1368     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1369   } else if (ptype == MATPRODUCT_AtB) {
1370     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1371   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1372     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1373     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1374   }
1375 
1376   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1377   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1378   PetscFunctionReturn(PETSC_SUCCESS);
1379 }
1380 
1381 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1382 {
1383   Mat                          A, B;
1384   Mat_Product                 *product;
1385   MatProductType               ptype;
1386   MatProductData_MPIAIJKokkos *pdata;
1387   MatMatStruct                *mm = NULL;
1388   PetscInt                     m, n, M, N;
1389   Mat                          Cd, Co;
1390   MPI_Comm                     comm;
1391 
1392   PetscFunctionBegin;
1393   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1394   MatCheckProduct(C, 1);
1395   product = C->product;
1396   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1397   ptype = product->type;
1398   A     = product->A;
1399   B     = product->B;
1400 
1401   switch (ptype) {
1402   case MATPRODUCT_AB:
1403     m = A->rmap->n;
1404     n = B->cmap->n;
1405     M = A->rmap->N;
1406     N = B->cmap->N;
1407     break;
1408   case MATPRODUCT_AtB:
1409     m = A->cmap->n;
1410     n = B->cmap->n;
1411     M = A->cmap->N;
1412     N = B->cmap->N;
1413     break;
1414   case MATPRODUCT_PtAP:
1415     m = B->cmap->n;
1416     n = B->cmap->n;
1417     M = B->cmap->N;
1418     N = B->cmap->N;
1419     break; /* BtAB */
1420   default:
1421     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1422   }
1423 
1424   PetscCall(MatSetSizes(C, m, n, M, N));
1425   PetscCall(PetscLayoutSetUp(C->rmap));
1426   PetscCall(PetscLayoutSetUp(C->cmap));
1427   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1428 
1429   pdata           = new MatProductData_MPIAIJKokkos();
1430   pdata->reusesym = product->api_user;
1431 
1432   if (ptype == MATPRODUCT_AB) {
1433     auto mmAB = new MatMatStruct_AB();
1434     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1435     mm = pdata->mmAB = mmAB;
1436   } else if (ptype == MATPRODUCT_AtB) {
1437     auto mmAtB = new MatMatStruct_AtB();
1438     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1439     mm = pdata->mmAtB = mmAtB;
1440   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1441     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1442 
1443     auto mmAB = new MatMatStruct_AB();
1444     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1445     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1446     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1447     pdata->mmAB = mmAB;
1448 
1449     m = A->rmap->n; // Z's layout
1450     n = B->cmap->n;
1451     M = A->rmap->N;
1452     N = B->cmap->N;
1453     PetscCall(MatCreate(comm, &Z));
1454     PetscCall(MatSetSizes(Z, m, n, M, N));
1455     PetscCall(PetscLayoutSetUp(Z->rmap));
1456     PetscCall(PetscLayoutSetUp(Z->cmap));
1457     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1458     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1459 
1460     auto mmAtB = new MatMatStruct_AtB();
1461     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1462 
1463     pdata->Z = Z; // give ownership to pdata
1464     mm = pdata->mmAtB = mmAtB;
1465   }
1466 
1467   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1468   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1469   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1470 
1471   C->product->data       = pdata;
1472   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1473   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1474   PetscFunctionReturn(PETSC_SUCCESS);
1475 }
1476 
1477 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1478 {
1479   Mat_Product *product = mat->product;
1480   PetscBool    match   = PETSC_FALSE;
1481   PetscBool    usecpu  = PETSC_FALSE;
1482 
1483   PetscFunctionBegin;
1484   MatCheckProduct(mat, 1);
1485   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1486   if (match) { /* we can always fallback to the CPU if requested */
1487     switch (product->type) {
1488     case MATPRODUCT_AB:
1489       if (product->api_user) {
1490         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1491         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1492         PetscOptionsEnd();
1493       } else {
1494         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1495         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1496         PetscOptionsEnd();
1497       }
1498       break;
1499     case MATPRODUCT_AtB:
1500       if (product->api_user) {
1501         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1502         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1503         PetscOptionsEnd();
1504       } else {
1505         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1506         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1507         PetscOptionsEnd();
1508       }
1509       break;
1510     case MATPRODUCT_PtAP:
1511       if (product->api_user) {
1512         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1513         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1514         PetscOptionsEnd();
1515       } else {
1516         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1517         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1518         PetscOptionsEnd();
1519       }
1520       break;
1521     default:
1522       break;
1523     }
1524     match = (PetscBool)!usecpu;
1525   }
1526   if (match) {
1527     switch (product->type) {
1528     case MATPRODUCT_AB:
1529     case MATPRODUCT_AtB:
1530     case MATPRODUCT_PtAP:
1531       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1532       break;
1533     default:
1534       break;
1535     }
1536   }
1537   /* fallback to MPIAIJ ops */
1538   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1539   PetscFunctionReturn(PETSC_SUCCESS);
1540 }
1541 
1542 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1543 {
1544   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1545   Mat_MPIAIJKokkos *mpikok;
1546 
1547   PetscFunctionBegin;
1548   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1549   mat->preallocated = PETSC_TRUE;
1550   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1551   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1552   PetscCall(MatZeroEntries(mat));
1553   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1554   delete mpikok;
1555   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1556   PetscFunctionReturn(PETSC_SUCCESS);
1557 }
1558 
1559 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1560 {
1561   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1562   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1563   Mat                         A = mpiaij->A, B = mpiaij->B;
1564   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1565   MatScalarKokkosView         Aa, Ba;
1566   MatScalarKokkosView         v1;
1567   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1568   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1569   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1570   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1571   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1572   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1573   PetscMemType                memtype;
1574 
1575   PetscFunctionBegin;
1576   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1577   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1578     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1579   } else {
1580     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1581   }
1582 
1583   if (imode == INSERT_VALUES) {
1584     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1585     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1586   } else {
1587     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1588     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1589   }
1590 
1591   /* Pack entries to be sent to remote */
1592   Kokkos::parallel_for(
1593     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1594 
1595   /* Send remote entries to their owner and overlap the communication with local computation */
1596   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1597   /* Add local entries to A and B in one kernel */
1598   Kokkos::parallel_for(
1599     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1600       PetscScalar sum = 0.0;
1601       if (i < Annz) {
1602         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1603         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1604       } else {
1605         i -= Annz;
1606         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1607         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1608       }
1609     });
1610   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1611 
1612   /* Add received remote entries to A and B in one kernel */
1613   Kokkos::parallel_for(
1614     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1615       if (i < Annz2) {
1616         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1617       } else {
1618         i -= Annz2;
1619         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1620       }
1621     });
1622 
1623   if (imode == INSERT_VALUES) {
1624     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1625     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1626   } else {
1627     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1628     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1629   }
1630   PetscFunctionReturn(PETSC_SUCCESS);
1631 }
1632 
1633 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1634 {
1635   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1636 
1637   PetscFunctionBegin;
1638   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1639   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1640   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1641   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1642   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1643   PetscCall(MatDestroy_MPIAIJ(A));
1644   PetscFunctionReturn(PETSC_SUCCESS);
1645 }
1646 
1647 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1648 {
1649   Mat         B;
1650   Mat_MPIAIJ *a;
1651 
1652   PetscFunctionBegin;
1653   if (reuse == MAT_INITIAL_MATRIX) {
1654     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1655   } else if (reuse == MAT_REUSE_MATRIX) {
1656     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1657   }
1658   B = *newmat;
1659 
1660   B->boundtocpu = PETSC_FALSE;
1661   PetscCall(PetscFree(B->defaultvectype));
1662   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1663   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1664 
1665   a = static_cast<Mat_MPIAIJ *>(A->data);
1666   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1667   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1668   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1669 
1670   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1671   B->ops->mult                  = MatMult_MPIAIJKokkos;
1672   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1673   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1674   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1675   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1676 
1677   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1678   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1679   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1680   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1681   PetscFunctionReturn(PETSC_SUCCESS);
1682 }
1683 /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos

   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
1687 
1688    Options Database Key:
1689 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1690 
1691   Level: beginner
1692 
1693 .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1694 M*/
1695 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1696 {
1697   PetscFunctionBegin;
1698   PetscCall(PetscKokkosInitializeCheck());
1699   PetscCall(MatCreate_MPIAIJ(A));
1700   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1701   PetscFunctionReturn(PETSC_SUCCESS);
1702 }
1703 
1704 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations.
1708 
1709    Collective
1710 
1711    Input Parameters:
+  comm - MPI communicator; a `MATSEQAIJKOKKOS` matrix is created if it has a single rank, otherwise a `MATMPIAIJKOKKOS` matrix
1713 .  m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1714            This value should be the same as the local size used in creating the
1715            y vector for the matrix-vector product y = Ax.
1716 .  n - This value should be the same as the local size used in creating the
1717        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1718        calculated if N is given) For square matrices n is almost always `m`.
1719 .  M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1720 .  N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1721 .  d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1722            (same value is used for all local rows)
1723 .  d_nnz - array containing the number of nonzeros in the various rows of the
1724            DIAGONAL portion of the local submatrix (possibly different for each row)
1725            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1726            The size of this array is equal to the number of local rows, i.e `m`.
1727            For matrices you plan to factor you must leave room for the diagonal entry and
1728            put in the entry even if it is zero.
1729 .  o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1730            submatrix (same value is used for all local rows).
1731 -  o_nnz - array containing the number of nonzeros in the various rows of the
1732            OFF-DIAGONAL portion of the local submatrix (possibly different for
1733            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1734            structure. The size of this array is equal to the number
1735            of local rows, i.e `m`.
1736 
1737    Output Parameter:
1738 .  A - the matrix
1739 
1740    Level: intermediate
1741 
1742    Notes:
1743    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1744    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1745    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1746 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1748    storage.  That is, the stored row and column indices can begin at
1749    either one (as in Fortran) or zero.
1750 
.seealso: [](chapter_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1753 @*/
1754 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1755 {
1756   PetscMPIInt size;
1757 
1758   PetscFunctionBegin;
1759   PetscCall(MatCreate(comm, A));
1760   PetscCall(MatSetSizes(*A, m, n, M, N));
1761   PetscCallMPI(MPI_Comm_size(comm, &size));
1762   if (size > 1) {
1763     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1764     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1765   } else {
1766     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1767     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1768   }
1769   PetscFunctionReturn(PETSC_SUCCESS);
1770 }
1771 
1772 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1773 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1774 {
1775   PetscMPIInt                size, rank;
1776   MPI_Comm                   comm;
1777   PetscSplitCSRDataStructure d_mat = NULL;
1778 
1779   PetscFunctionBegin;
1780   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1781   PetscCallMPI(MPI_Comm_size(comm, &size));
1782   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1783   if (size == 1) {
1784     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1785     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1786   } else {
1787     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1788     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1789     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1790     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1791     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1792   }
1793   // act like MatSetValues because not called on host
1794   if (A->assembled) {
1795     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1796     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1797   } else {
1798     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%d\n", A->assembled));
1799   }
1800   if (!d_mat) {
1801     struct _n_SplitCSRMat h_mat; /* host container */
1802     Mat_SeqAIJKokkos     *aijkokA;
1803     Mat_SeqAIJ           *jaca;
1804     PetscInt              n = A->rmap->n, nnz;
1805     Mat                   Amat;
1806     PetscInt             *colmap;
1807 
1808     /* create and copy h_mat */
1809     h_mat.M = A->cmap->N; // use for debug build
1810     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1811     if (size == 1) {
1812       Amat            = A;
1813       jaca            = (Mat_SeqAIJ *)A->data;
1814       h_mat.rstart    = 0;
1815       h_mat.rend      = A->rmap->n;
1816       h_mat.cstart    = 0;
1817       h_mat.cend      = A->cmap->n;
1818       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1819       h_mat.offdiag.a                   = NULL;
1820       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1821     } else {
1822       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1823       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1824       PetscInt          ii;
1825       Mat_SeqAIJKokkos *aijkokB;
1826 
1827       Amat    = aij->A;
1828       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1829       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1830       jaca    = (Mat_SeqAIJ *)aij->A->data;
1831       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1832       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1833       aij->donotstash          = PETSC_TRUE;
1834       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1835       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1836       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1837       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1838       // allocate B copy data
1839       h_mat.rstart = A->rmap->rstart;
1840       h_mat.rend   = A->rmap->rend;
1841       h_mat.cstart = A->cmap->rstart;
1842       h_mat.cend   = A->cmap->rend;
1843       nnz          = jacb->i[n];
1844       if (jacb->compressedrow.use) {
1845         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1846         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1847         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1848         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1849       } else {
1850         h_mat.offdiag.i = aijkokB->i_device_data();
1851       }
1852       h_mat.offdiag.j = aijkokB->j_device_data();
1853       h_mat.offdiag.a = aijkokB->a_device_data();
1854       {
1855         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1856         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1857         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1858         h_mat.colmap = aijkokB->colmap_d.data();
1859         PetscCall(PetscFree(colmap));
1860       }
1861       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1862       h_mat.offdiag.n                 = n;
1863     }
1864     // allocate A copy data
1865     nnz                          = jaca->i[n];
1866     h_mat.diag.n                 = n;
1867     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1868     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1869     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
1870     h_mat.diag.i = aijkokA->i_device_data();
1871     h_mat.diag.j = aijkokA->j_device_data();
1872     h_mat.diag.a = aijkokA->a_device_data();
1873     // copy pointers and metadata to device
1874     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1875     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1876     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1877   }
1878   *B           = d_mat;       // return it, set it in Mat, and set it up
1879   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1880   PetscFunctionReturn(PETSC_SUCCESS);
1881 }
1882 
1883 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1884 {
1885   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1886 
1887   PetscFunctionBegin;
1888   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1889   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1890   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1891   else *mask = "PETSC_OFFLOAD_BOTH";
1892   PetscFunctionReturn(PETSC_SUCCESS);
1893 }
1894 
1895 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1896 {
1897   PetscMPIInt size;
1898   Mat         Ad, Ao;
1899   const char *amask, *bmask;
1900 
1901   PetscFunctionBegin;
1902   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1903 
1904   if (size == 1) {
1905     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1906     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1907   } else {
1908     Ad = ((Mat_MPIAIJ *)A->data)->A;
1909     Ao = ((Mat_MPIAIJ *)A->data)->B;
1910     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1911     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1912     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1913   }
1914   PetscFunctionReturn(PETSC_SUCCESS);
1915 }
1916