xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision a29dfd43bb0c77e2653d3bfa2c953f902720a6d2)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <petsc/private/kokkosimpl.hpp>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <KokkosSparse_spadd.hpp>
9 #include <KokkosSparse_spgemm.hpp>
10 
11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12 {
13   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19    */
20   if (mode == MAT_FINAL_ASSEMBLY) {
21     PetscScalarKokkosView v;
22 
23     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
24     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
26     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
28   }
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33 {
34   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
35 
36   PetscFunctionBegin;
37   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
38   if (mat->hash_active) {
39     mat->ops[0]      = mpiaij->cops;
40     mat->hash_active = PETSC_FALSE;
41   }
42 
43   PetscCall(PetscLayoutSetUp(mat->rmap));
44   PetscCall(PetscLayoutSetUp(mat->cmap));
45 #if defined(PETSC_USE_DEBUG)
46   if (d_nnz) {
47     PetscInt i;
48     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
49   }
50   if (o_nnz) {
51     PetscInt i;
52     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
53   }
54 #endif
55 #if defined(PETSC_USE_CTABLE)
56   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
57 #else
58   PetscCall(PetscFree(mpiaij->colmap));
59 #endif
60   PetscCall(PetscFree(mpiaij->garray));
61   PetscCall(VecDestroy(&mpiaij->lvec));
62   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
63   /* Because the B will have been resized we simply destroy it and create a new one each time */
64   PetscCall(MatDestroy(&mpiaij->B));
65 
66   if (!mpiaij->A) {
67     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
68     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
69   }
70   if (!mpiaij->B) {
71     PetscMPIInt size;
72     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
73     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
74     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
75   }
76   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
77   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
78   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
79   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
80   mat->preallocated = PETSC_TRUE;
81   PetscFunctionReturn(PETSC_SUCCESS);
82 }
83 
84 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
85 {
86   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
87   PetscInt    nt;
88 
89   PetscFunctionBegin;
90   PetscCall(VecGetLocalSize(xx, &nt));
91   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
92   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
93   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
94   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
95   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
96   PetscFunctionReturn(PETSC_SUCCESS);
97 }
98 
99 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
100 {
101   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
102   PetscInt    nt;
103 
104   PetscFunctionBegin;
105   PetscCall(VecGetLocalSize(xx, &nt));
106   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
107   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
108   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
109   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
110   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
111   PetscFunctionReturn(PETSC_SUCCESS);
112 }
113 
114 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
115 {
116   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
117   PetscInt    nt;
118 
119   PetscFunctionBegin;
120   PetscCall(VecGetLocalSize(xx, &nt));
121   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
122   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
123   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
124   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
125   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
126   PetscFunctionReturn(PETSC_SUCCESS);
127 }
128 
129 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
130    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
131    C still uses local column ids. Their corresponding global column ids are returned in glob.
132 */
133 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
134 {
135   Mat             Ad, Ao;
136   const PetscInt *cmap;
137 
138   PetscFunctionBegin;
139   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
140   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
141   if (glob) {
142     PetscInt cst, i, dn, on, *gidx;
143     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
144     PetscCall(MatGetLocalSize(Ao, NULL, &on));
145     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
146     PetscCall(PetscMalloc1(dn + on, &gidx));
147     for (i = 0; i < dn; i++) gidx[i] = cst + i;
148     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
149     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
150   }
151   PetscFunctionReturn(PETSC_SUCCESS);
152 }
153 
154 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
155 struct MatMatStruct {
156   PetscInt            n, *garray;     // C's garray and its size.
157   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
158   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
159   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
160   PetscIntKokkosView  E_NzLeft;
161   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
162   MatScalarKokkosView rootBuf, leafBuf;
163   KokkosCsrMatrix     Fd, Fo; // F in split form
164 
165   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
166   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
167   KernelHandle kh3; // compute C3
168   KernelHandle kh4; // compute C4
169 
170   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
171   PetscInt E_VectorLength;
172   PetscInt E_RowsPerTeam;
173   PetscInt F_TeamSize;
174   PetscInt F_VectorLength;
175   PetscInt F_RowsPerTeam;
176 
177   ~MatMatStruct()
178   {
179     PetscFunctionBegin;
180     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
181     PetscFunctionReturnVoid();
182   }
183 };
184 
185 struct MatMatStruct_AB : public MatMatStruct {
186   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
187   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
188   PetscIntKokkosView rowoffset;
189 };
190 
191 struct MatMatStruct_AtB : public MatMatStruct {
192   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
193   MatColIdxKokkosView Fdjperm;
194   MatColIdxKokkosView Fojmap;
195   MatColIdxKokkosView Fojperm;
196 };
197 
198 struct MatProductData_MPIAIJKokkos {
199   MatMatStruct_AB  *mmAB     = nullptr;
200   MatMatStruct_AtB *mmAtB    = nullptr;
201   PetscBool         reusesym = PETSC_FALSE;
202   Mat               Z        = nullptr; // store Z=AB in computing BtAB
203 
204   ~MatProductData_MPIAIJKokkos()
205   {
206     delete mmAB;
207     delete mmAtB;
208     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
209   }
210 };
211 
212 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
213 {
214   PetscFunctionBegin;
215   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
220    It is similar to MatCreateMPIAIJWithSplitArrays.
221 
222   Input Parameters:
223 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
224 .  A     - the diag matrix using local col ids
225 -  B     - the offdiag matrix using global col ids
226 
227   Output Parameter:
228 .  mat   - the updated MATMPIAIJKOKKOS matrix
229 */
230 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
231 {
232   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
233   PetscInt    m, n, M, N, Am, An, Bm, Bn;
234 
235   PetscFunctionBegin;
236   PetscCall(MatGetSize(mat, &M, &N));
237   PetscCall(MatGetLocalSize(mat, &m, &n));
238   PetscCall(MatGetLocalSize(A, &Am, &An));
239   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
240 
241   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
242   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
243   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
244   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
245   mpiaij->A      = A;
246   mpiaij->B      = B;
247   mpiaij->garray = garray;
248 
249   mat->preallocated     = PETSC_TRUE;
250   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
251 
252   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
253   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
254   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
255     also gets mpiaij->B compacted, with its col ids and size reduced
256   */
257   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
258   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
259   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
260   PetscFunctionReturn(PETSC_SUCCESS);
261 }
262 
263 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
264 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
265 template <class ExecutionSpace>
266 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
267 {
268   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
269 
270   PetscFunctionBegin;
271   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
272 
273   if (nnz_per_row < 1) nnz_per_row = 1;
274 
275   int max_vector_length = teamPolicy.vector_length_max();
276 
277   if (vector_length < 1) {
278     vector_length = 1;
279     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
280   }
281 
282   // Determine rows per thread
283   if (rows_per_thread < 1) {
284     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
285     else {
286       if (nnz_per_row < 20 && nnz > 5000000) {
287         rows_per_thread = 256;
288       } else rows_per_thread = 64;
289     }
290   }
291 
292   if (team_size < 1) {
293     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
294       team_size = 256 / vector_length;
295     } else {
296       team_size = 1;
297     }
298   }
299 
300   rows_per_team = rows_per_thread * team_size;
301 
302   if (rows_per_team < 0) {
303     PetscInt nnz_per_team = 4096;
304     PetscInt conc         = ExecutionSpace().concurrency();
305     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
306     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
307   }
308   PetscFunctionReturn(PETSC_SUCCESS);
309 }
310 
311 /*
312   Reduce two sets of global indices into local ones
313 
314   Input Parameters:
315 +  n1          - size of garray1[], the first set
316 .  garray1[n1] - a sorted global index array (without duplicates)
317 .  m           - size of indices[], the second set
318 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
319 
320   Output Parameters:
321 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
322 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
323 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
324 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
325 
326    Example, say
327     n1         = 5
328     garray1[5] = {1, 4, 7, 8, 10}
329     m          = 4
330     indices[4] = {2, 4, 8, 9}
331 
332    Combining them together, we have 7 global indices in garray2[]
333     n2         = 7
334     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
335 
336    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
337     map[5] = {0, 2, 3, 4, 6}
338 
339    On output, indices[] is updated with local indices
340     indices[4] = {1, 2, 4, 5}
341 */
342 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
343 {
344   PetscHMapI    g2l = nullptr;
345   PetscHashIter iter;
346   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
347   PetscInt      n2, *garray2;
348 
349   PetscFunctionBegin;
350   tot = 0;
351   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
352   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
353     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
354     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
355   }
356 
357   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
358     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
359     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
360   }
361 
362   // Pull out (unique) globals in the hash table and put them in garray2[]
363   n2 = tot;
364   PetscCall(PetscMalloc1(n2, &garray2));
365   tot = 0;
366   PetscHashIterBegin(g2l, iter);
367   while (!PetscHashIterAtEnd(g2l, iter)) {
368     PetscHashIterGetKey(g2l, iter, key);
369     PetscHashIterNext(g2l, iter);
370     garray2[tot++] = key;
371   }
372 
373   // Sort garray2[] and then map them to local indices starting from 0
374   PetscCall(PetscSortInt(n2, garray2));
375   PetscCall(PetscHMapIClear(g2l));
376   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
377 
378   // Rewrite indices[] with local indices
379   for (PetscInt i = 0; i < m; i++) {
380     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
381     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
382     indices[i] = val;
383   }
384   // Record the map that maps garray1[i] to garray2[map[i]]
385   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
386   PetscCall(PetscHMapIDestroy(&g2l));
387   *n2_      = n2;
388   *garray2_ = garray2;
389   PetscFunctionReturn(PETSC_SUCCESS);
390 }
391 
392 /*
393   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
394 
395   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
396 
397   Think of each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
398   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
399 
400   Input Parameters:
401 +  comm       - MPI communicator of E
402 .  A          - diag block of E, using local column indices
403 .  B          - off-diag block of E, using local column indices
404 .  cstart      - (global) start column of Ed
405 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
406 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
407 .  ownerSF     - the SF specifies ownership (root) of rows in E
408 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
409 -  mm          - to stash intermediate data structures for reuse
410 
411   Output Parameters:
412 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
413 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
414 
415   Notes:
416   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
417 
418  */
419 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
420 {
421   PetscFunctionBegin;
422   if (reuse == MAT_INITIAL_MATRIX) {
423     PetscInt Em = A.numRows(), Fm;
424     PetscInt n1 = B.numCols();
425 
426     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
427 
428     // Do the analysis on host
429     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
430     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
431     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
432     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
433     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
434     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
435 
436     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
437     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
438     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
439     for (PetscInt i = 0; i < Em; i++) {
440       const PetscInt *first, *last, *it;
441       PetscInt        count, step;
442       // std::lower_bound(first,last,cstart), but need to use global column indices
443       first = Bj + Bi[i];
444       last  = Bj + Bi[i + 1];
445       count = last - first;
446       while (count > 0) {
447         it   = first;
448         step = count / 2;
449         it += step;
450         if (garray1[*it] < cstart) { // map local to global
451           first = ++it;
452           count -= step + 1;
453         } else count = step;
454       }
455       E_NzLeft[i] = first - (Bj + Bi[i]);
456       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
457     }
458 
459     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
460     const PetscMPIInt *iranks, *ranks;
461     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
462     PetscInt           niranks, nranks;
463     MPI_Request       *reqs;
464     PetscMPIInt        tag;
465     PetscSF            reduceSF;
466     PetscInt          *sdisp, *rdisp;
467 
468     PetscCall(PetscCommGetNewTag(comm, &tag));
469     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
470     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
471 
472     // Find out length of each row I will receive. Even for the same row index, when they are from
473     // different senders, they might have different lengths (and sparsity patterns)
474     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
475     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
476 
477     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
478 
479     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
480     recvRowLen[0] = 0; // since we will make it in CSR format later
481     recvRowLen++;      // advance the pointer now
482     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
483     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
484     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
485 
486     // Build the real PetscSF for reducing E rows (buffer to buffer)
487     rdisp[0] = 0;
488     for (PetscInt i = 0; i < niranks; i++) {
489       rdisp[i + 1] = rdisp[i];
490       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
491     }
492     recvRowLen--; // put it back into csr format
493     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
494 
495     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
496     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
497     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
498 
499     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
500     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
501     PetscSFNode *iremote;
502 
503     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
504     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
505     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
506 
507     for (PetscInt i = 0; i < nranks; i++) {
508       PetscInt count = 0;
509       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
510       for (PetscInt j = 0; j < count; j++) {
511         iremote[nleaves + j].rank  = ranks[i];
512         iremote[nleaves + j].index = sdisp[i] + j;
513       }
514       nleaves += count;
515     }
516     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
517 
518     PetscCall(PetscSFCreate(comm, &reduceSF));
519     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
520 
521     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
522     PetscInt *sendCol, *recvCol;
523     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
524     for (PetscInt k = 0; k < roffset[nranks]; k++) {
525       PetscInt  i      = rmine[k]; // row to be copied
526       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
527       PetscInt  nzLeft = E_NzLeft[i];
528       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
529       for (PetscInt j = 0; j < alen + blen; j++) {
530         if (j < nzLeft) {
531           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
532         } else if (j < nzLeft + alen) {
533           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
534         } else {
535           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
536         }
537       }
538     }
539     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
540     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
541 
542     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
543     PetscInt *recvRowPerm, *recvColSorted;
544     PetscInt *recvNzPerm, *recvNzPermSorted;
545     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
546 
547     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
548     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
549     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
550 
551     // i[] array, nz are always easiest to compute
552     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
553     MatRowMapType          *Fdi, *Foi;
554     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
555     PetscInt                iter;
556 
557     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
558     Kokkos::deep_copy(Foi_h, 0);
559     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
560     Foi  = Foi_h.data() + 1;
561     iter = 0;
562     while (iter < recvRowCnt) { // iter over received rows
563       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
564       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
565 
566       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
567 
568       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
569       PetscInt  nz    = 0; // nz (with dups) in the current row
570       PetscInt *jbuf  = recvColSorted + FnzDups;
571       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
572       PetscInt *jbuf2 = jbuf; // temp pointers
573       PetscInt *pbuf2 = pbuf;
574       for (PetscInt d = 0; d < dupRows; d++) {
575         PetscInt i   = recvRowPerm[iter + d];
576         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
577         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
578         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
579         jbuf2 += len;
580         pbuf2 += len;
581         nz += len;
582       }
583       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
584 
585       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
586       PetscInt cur = 0;
587       while (cur < nz) {
588         PetscInt curColIdx = jbuf[cur];
589         PetscInt dups      = 1;
590 
591         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
592         if (curColIdx >= cstart && curColIdx < cend) {
593           Fdi[curRowIdx]++;
594           FdnzDups += dups;
595         } else {
596           Foi[curRowIdx]++;
597           FonzDups += dups;
598         }
599         cur += dups;
600       }
601 
602       FnzDups += nz;
603       iter += dupRows; // Move to next unique row
604     }
605 
606     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
607     Foi = Foi_h.data();
608     for (PetscInt i = 0; i < Fm; i++) {
609       Fdi[i + 1] += Fdi[i];
610       Foi[i + 1] += Foi[i];
611     }
612     Fdnz = Fdi[Fm];
613     Fonz = Foi[Fm];
614     PetscCall(PetscFree2(sendCol, recvCol));
615 
616     // Allocate j, jmap, jperm for Fd and Fo
617     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
618     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
619     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
620     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
621     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
622     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
623 
624     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
625     Fdjmap[0] = 0;
626     Fojmap[0] = 0;
627     FnzDups   = 0;
628     Fdnz      = 0;
629     Fonz      = 0;
630     iter      = 0; // iter over received rows
631     while (iter < recvRowCnt) {
632       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
633       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
634       PetscInt nz        = 0;                           // nz (with dups) in the current row
635 
636       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
637       for (PetscInt d = 0; d < dupRows; d++) {
638         PetscInt i = recvRowPerm[iter + d];
639         nz += recvRowLen[i + 1] - recvRowLen[i];
640       }
641 
642       PetscInt *jbuf = recvColSorted + FnzDups;
643       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
644       PetscInt cur = 0;
645       while (cur < nz) {
646         PetscInt curColIdx = jbuf[cur];
647         PetscInt dups      = 1;
648 
649         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
650         if (curColIdx >= cstart && curColIdx < cend) {
651           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
652           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
653           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
654           FdnzDups += dups;
655           Fdnz++;
656         } else {
657           Foj[Fonz]        = curColIdx; // in global
658           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
659           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
660           FonzDups += dups;
661           Fonz++;
662         }
663         cur += dups;
664         FnzDups += dups;
665       }
666       iter += dupRows; // Move to next unique row
667     }
668     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
669     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
670 
671     // Combine global column indices in garray1[] and Foj[]
672     PetscInt n2, *garray2;
673 
674     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
675     mm->sf       = reduceSF;
676     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
677     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
678     mm->garray   = garray2; // give ownership, so no free
679     mm->n        = n2;
680     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
681     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
682     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
683     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
684     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
685 
686     // Output Fd and Fo in KokkosCsrMatrix format
687     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
688     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
689     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
690     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
691     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
692     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
693 
694     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
695     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
696 
697     // Compute kernel launch parameters in merging E
698     PetscInt teamSize, vectorLength, rowsPerTeam;
699 
700     teamSize = vectorLength = rowsPerTeam = -1;
701     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
702     mm->E_TeamSize     = teamSize;
703     mm->E_VectorLength = vectorLength;
704     mm->E_RowsPerTeam  = rowsPerTeam;
705   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
706 
707   // Handy aliases
708   auto       &Aa           = A.values;
709   auto       &Ba           = B.values;
710   const auto &Ai           = A.graph.row_map;
711   const auto &Bi           = B.graph.row_map;
712   const auto &E_NzLeft     = mm->E_NzLeft;
713   auto       &leafBuf      = mm->leafBuf;
714   auto       &rootBuf      = mm->rootBuf;
715   PetscSF     reduceSF     = mm->sf;
716   PetscInt    Em           = A.numRows();
717   PetscInt    teamSize     = mm->E_TeamSize;
718   PetscInt    vectorLength = mm->E_VectorLength;
719   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
720   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
721 
722   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
723   PetscCallCXX(Kokkos::parallel_for(
724     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
725       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
726         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
727         if (i < Em) {
728           PetscInt disp   = Ai(i) + Bi(i);
729           PetscInt alen   = Ai(i + 1) - Ai(i);
730           PetscInt blen   = Bi(i + 1) - Bi(i);
731           PetscInt nzleft = E_NzLeft(i);
732 
733           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
734             MatScalar &val = leafBuf(disp + j);
735             if (j < nzleft) { // B left
736               val = Ba(Bi(i) + j);
737             } else if (j < nzleft + alen) { // diag A
738               val = Aa(Ai(i) + j - nzleft);
739             } else { // B right
740               val = Ba(Bi(i) + j - alen);
741             }
742           });
743         }
744       });
745     }));
746   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
747   PetscFunctionReturn(PETSC_SUCCESS);
748 }
749 
750 // To finish MatMPIAIJKokkosReduce.
751 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
752 {
753   auto       &leafBuf  = mm->leafBuf;
754   auto       &rootBuf  = mm->rootBuf;
755   auto       &Fda      = mm->Fd.values;
756   const auto &Fdjmap   = mm->Fdjmap;
757   const auto &Fdjperm  = mm->Fdjperm;
758   auto        Fdnz     = mm->Fd.nnz();
759   auto       &Foa      = mm->Fo.values;
760   const auto &Fojmap   = mm->Fojmap;
761   const auto &Fojperm  = mm->Fojperm;
762   auto        Fonz     = mm->Fo.nnz();
763   PetscSF     reduceSF = mm->sf;
764 
765   PetscFunctionBegin;
766   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
767 
768   // Reduce data in rootBuf to Fd and Fo
769   PetscCallCXX(Kokkos::parallel_for(
770     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
771       PetscScalar sum = 0.0;
772       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
773       Fda(i) = sum;
774     }));
775 
776   PetscCallCXX(Kokkos::parallel_for(
777     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
778       PetscScalar sum = 0.0;
779       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
780       Foa(i) = sum;
781     }));
782   PetscFunctionReturn(PETSC_SUCCESS);
783 }
784 
785 /*
786   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
787 
788   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
789   device and involves various index mapping.
790 
791   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
792   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
793   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
794   F has the same column layout as E.
795 
796   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
797   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
798   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
799   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
800   column indices in Fo and update Fo with local indices.
801 
802    Input Parameters:
803 +   E       - the MPIAIJKOKKOS matrix
804 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
805 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
806 -   mm      - to stash matproduct intermediate data structures
807 
808     Output Parameters:
809 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
810 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
811 
812     Notes:
813     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
814     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
815 */
816 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
817 {
818   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
819   Mat               A = empi->A, B = empi->B; // diag and off-diag
820   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
821   PetscInt          Em = E->rmap->n; // #local rows
822   MPI_Comm          comm;
823 
824   PetscFunctionBegin;
825   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
826   if (reuse == MAT_INITIAL_MATRIX) {
827     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
828     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
829     const PetscInt *garray1 = empi->garray; // its size is n1
830     PetscInt        cstart, cend;
831     PetscSF         bcastSF;
832 
833     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
834 
835     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
836     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
837     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
838     for (PetscInt i = 0; i < Em; i++) {
839       const PetscInt *first, *last, *it;
840       PetscInt        count, step;
841       // std::lower_bound(first,last,cstart), but need to use global column indices
842       first = Bj + Bi[i];
843       last  = Bj + Bi[i + 1];
844       count = last - first;
845       while (count > 0) {
846         it   = first;
847         step = count / 2;
848         it += step;
849         if (empi->garray[*it] < cstart) { // map local to global
850           first = ++it;
851           count -= step + 1;
852         } else count = step;
853       }
854       E_NzLeft[i] = first - (Bj + Bi[i]);
855       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
856     }
857 
858     // Compute row pointer Fi of F
859     PetscInt *Fi, Fm, Fnz;
860     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
861     PetscCall(PetscMalloc1(Fm + 1, &Fi));
862     Fi[0] = 0;
863     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
864     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
865     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
866     Fnz = Fi[Fm];
867 
868     // Build the real PetscSF for bcasting E rows (buffer to buffer)
869     const PetscMPIInt *iranks, *ranks;
870     const PetscInt    *ioffset, *irootloc, *roffset;
871     PetscInt           niranks, nranks, *sdisp, *rdisp;
872     MPI_Request       *reqs;
873     PetscMPIInt        tag;
874 
875     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
876     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
877     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
878 
879     sdisp[0] = 0; // send displacement
880     for (PetscInt i = 0; i < niranks; i++) {
881       sdisp[i + 1] = sdisp[i];
882       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
883         PetscInt r = irootloc[j]; // row to be sent
884         sdisp[i + 1] += E_RowLen[r];
885       }
886     }
887 
888     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
889     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
890     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
891     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
892 
893     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
894     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
895     PetscSFNode *iremote;                  // give ownership to bcastSF
896     PetscCall(PetscMalloc1(nleaves, &iremote));
897     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
898       PetscInt k = 0;
899       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
900         iremote[j].rank  = ranks[i];
901         iremote[j].index = rdisp[i] + k; // their root location
902         k++;
903       }
904     }
905     PetscCall(PetscSFCreate(comm, &bcastSF));
906     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
907     PetscCall(PetscFree3(sdisp, rdisp, reqs));
908 
909     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
910     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
911     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
912     rowoffset[0]                     = 0;
913     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
914 
915     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
916     PetscInt *jbuf, *Fj;
917     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
918     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
919       PetscInt  i      = irootloc[k]; // row to be copied
920       PetscInt *buf    = &jbuf[rowoffset[k]];
921       PetscInt  nzLeft = E_NzLeft[i];
922       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
923       for (PetscInt j = 0; j < alen + blen; j++) {
924         if (j < nzLeft) {
925           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
926         } else if (j < nzLeft + alen) {
927           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
928         } else {
929           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
930         }
931       }
932     }
933     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
934     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
935 
936     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
937     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
938     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
939     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
940     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
941 
942     Fdi[0] = Foi[0] = 0;
943     for (PetscInt i = 0; i < Fm; i++) {
944       PetscInt *first, *last, *lb1, *lb2;
945       // cut the row into: Left, [cstart, cend), Right
946       first       = Fj + Fi[i];
947       last        = Fj + Fi[i + 1];
948       lb1         = std::lower_bound(first, last, cstart);
949       F_NzLeft[i] = lb1 - first;
950       lb2         = std::lower_bound(first, last, cend);
951       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
952       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
953     }
954     for (PetscInt i = 0; i < Fm; i++) {
955       Fdi[i + 1] += Fdi[i];
956       Foi[i + 1] += Foi[i];
957     }
958 
959     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
960     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
961     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
962     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
963 
964     for (PetscInt i = 0; i < Fm; i++) {
965       PetscInt nzLeft = F_NzLeft[i];
966       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
967       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
968         gid = Fj[Fi[i] + j];
969         if (j < nzLeft) { // left, in global
970           Foj[Foi[i] + j] = gid;
971         } else if (j < nzLeft + len) { // diag, in local
972           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
973         } else { // right, in global
974           Foj[Foi[i] + j - len] = gid;
975         }
976       }
977     }
978     PetscCall(PetscFree2(jbuf, Fj));
979     PetscCall(PetscFree(Fi));
980 
981     // Reduce global indices in Foj[] and garray1[] into local ones
982     PetscInt n2, *garray2;
983     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
984 
985     // Record the plans built above, for reuse
986     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
987     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
988     Kokkos::deep_copy(irootloc_h, tmp);
989     mm->sf        = bcastSF;
990     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
991     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
992     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
993     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
994     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
995     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
996     mm->garray    = garray2;
997     mm->n         = n2;
998 
999     // Output Fd and Fo in KokkosCsrMatrix format
1000     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
1001     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
1002     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
1003     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
1004     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
1005 
1006     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1007     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1008 
1009     // Compute kernel launch parameters in merging E or splitting F
1010     PetscInt teamSize, vectorLength, rowsPerTeam;
1011 
1012     teamSize = vectorLength = rowsPerTeam = -1;
1013     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1014     mm->E_TeamSize     = teamSize;
1015     mm->E_VectorLength = vectorLength;
1016     mm->E_RowsPerTeam  = rowsPerTeam;
1017 
1018     teamSize = vectorLength = rowsPerTeam = -1;
1019     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1020     mm->F_TeamSize     = teamSize;
1021     mm->F_VectorLength = vectorLength;
1022     mm->F_RowsPerTeam  = rowsPerTeam;
1023   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1024 
1025   // Sync E's value to device
1026   akok->a_dual.sync_device();
1027   bkok->a_dual.sync_device();
1028 
1029   // Handy aliases
1030   const auto &Aa = akok->a_dual.view_device();
1031   const auto &Ba = bkok->a_dual.view_device();
1032   const auto &Ai = akok->i_dual.view_device();
1033   const auto &Bi = bkok->i_dual.view_device();
1034 
1035   // Fetch the plans
1036   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1037   PetscSF             &bcastSF   = mm->sf;
1038   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1039   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1040   PetscIntKokkosView  &irootloc  = mm->irootloc;
1041   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1042 
1043   PetscInt teamSize     = mm->E_TeamSize;
1044   PetscInt vectorLength = mm->E_VectorLength;
1045   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1046   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1047 
1048   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1049   PetscCallCXX(Kokkos::parallel_for(
1050     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1051       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1052         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1053         if (r < irootloc.extent(0)) {
1054           PetscInt i      = irootloc(r); // row i of E
1055           PetscInt disp   = rowoffset(r);
1056           PetscInt alen   = Ai(i + 1) - Ai(i);
1057           PetscInt blen   = Bi(i + 1) - Bi(i);
1058           PetscInt nzleft = E_NzLeft(i);
1059 
1060           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1061             if (j < nzleft) { // B left
1062               rootBuf(disp + j) = Ba(Bi(i) + j);
1063             } else if (j < nzleft + alen) { // diag A
1064               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1065             } else { // B right
1066               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1067             }
1068           });
1069         }
1070       });
1071     }));
1072   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1073   PetscFunctionReturn(PETSC_SUCCESS);
1074 }
1075 
1076 // To finish MatMPIAIJKokkosBcast.
1077 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1078 {
1079   PetscFunctionBegin;
1080   const auto &Fd  = mm->Fd;
1081   const auto &Fo  = mm->Fo;
1082   const auto &Fdi = Fd.graph.row_map;
1083   const auto &Foi = Fo.graph.row_map;
1084   auto       &Fda = Fd.values;
1085   auto       &Foa = Fo.values;
1086   auto        Fm  = Fd.numRows();
1087 
1088   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1089   PetscSF             &bcastSF      = mm->sf;
1090   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1091   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1092   PetscInt             teamSize     = mm->F_TeamSize;
1093   PetscInt             vectorLength = mm->F_VectorLength;
1094   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1095   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1096 
1097   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1098 
1099   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1100   PetscCallCXX(Kokkos::parallel_for(
1101     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1102       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1103         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1104         if (i < Fm) {
1105           PetscInt nzLeft = F_NzLeft(i);
1106           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1107           PetscInt blen   = Foi(i + 1) - Foi(i);
1108           PetscInt Fii    = Fdi(i) + Foi(i);
1109 
1110           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1111             PetscScalar val = leafBuf(Fii + j);
1112             if (j < nzLeft) { // left
1113               Foa(Foi(i) + j) = val;
1114             } else if (j < nzLeft + alen) { // diag
1115               Fda(Fdi(i) + j - nzLeft) = val;
1116             } else { // right
1117               Foa(Foi(i) + j - alen) = val;
1118             }
1119           });
1120         }
1121       });
1122     }));
1123   PetscFunctionReturn(PETSC_SUCCESS);
1124 }
1125 
1126 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1127 {
1128   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1129   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1130   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1131   PetscInt        cstart, cend;
1132   MPI_Comm        comm;
1133 
1134   PetscFunctionBegin;
1135   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1136   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1137   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1138   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1139   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1140   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1141   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1142 
1143   // TODO: add command line options to select spgemm algorithms
1144   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1145 
1146   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1147 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1148   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1149   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1150   #endif
1151 #endif
1152 
1153   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1154   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1155   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1156   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1157 
1158   // Aot * (B's diag + B's off-diag)
1159   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1160   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1161   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1162   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1163   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1164   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1165 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1166 
1167   PetscCallCXX(sort_crs_matrix(mm->C3));
1168   PetscCallCXX(sort_crs_matrix(mm->C4));
1169 #endif
1170 
1171   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1172   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1173   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1174   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1175 
1176   // Adt * (B's diag + B's off-diag)
1177   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1178   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1179   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1180   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1181 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1182   PetscCallCXX(sort_crs_matrix(mm->C1));
1183   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1184 #endif
1185 
1186   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1187 
1188   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1189   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1190   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1191   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1192   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1193 
1194   // C = (C1+Fd, C2+Fo)
1195   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1196   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1197   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1198   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1199   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1200   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1201   PetscFunctionReturn(PETSC_SUCCESS);
1202 }
1203 
1204 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1205 {
1206   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1207   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1208   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1209   MPI_Comm        comm;
1210 
1211   PetscFunctionBegin;
1212   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1213   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1214   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1215   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1216   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1217 
1218   // Aot * (B's diag + B's off-diag)
1219   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1220   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1221 
1222   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1223   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1224 
1225   // Adt * (B's diag + B's off-diag)
1226   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1227   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1228 
1229   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1230 
1231   // C = (C1+Fd, C2+Fo)
1232   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1233   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1234   PetscFunctionReturn(PETSC_SUCCESS);
1235 }
1236 
1237 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1238 
1239   Input Parameters:
1240 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1241 .  A        - an MPIAIJKOKKOS matrix
1242 .  B        - an MPIAIJKOKKOS matrix
1243 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1244 */
1245 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1246 {
1247   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1248   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1249   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1250 
1251   PetscFunctionBegin;
1252   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1253   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1254   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1255   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1256 
1257   // TODO: add command line options to select spgemm algorithms
1258   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1259 
1260   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1261 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1262   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1263   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1264   #endif
1265 #endif
1266 
1267   mm->kh1.create_spgemm_handle(spgemm_alg);
1268   mm->kh2.create_spgemm_handle(spgemm_alg);
1269   mm->kh3.create_spgemm_handle(spgemm_alg);
1270   mm->kh4.create_spgemm_handle(spgemm_alg);
1271 
1272   // Bcast B's rows to form F, and overlap the communication
1273   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1274   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1275 
1276   // A's diag * (B's diag + B's off-diag)
1277   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1278   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1279   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1280   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1281   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1282   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1283 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1284   PetscCallCXX(sort_crs_matrix(mm->C1));
1285   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1286 #endif
1287 
1288   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1289 
1290   // A's off-diag * (F's diag + F's off-diag)
1291   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1292   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1293   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1294   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1295 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1296   PetscCallCXX(sort_crs_matrix(mm->C3));
1297   PetscCallCXX(sort_crs_matrix(mm->C4));
1298 #endif
1299 
1300   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1301   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1302   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1303   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1304   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1305 
1306   // C = (Cd, Co) = (C1+C3, C2+C4)
1307   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1308   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1309   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1310   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1311   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1312   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1313   PetscFunctionReturn(PETSC_SUCCESS);
1314 }
1315 
1316 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1317 {
1318   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1319   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1320   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1321 
1322   PetscFunctionBegin;
1323   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1324   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1325   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1326   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1327 
1328   // Bcast B's rows to form F, and overlap the communication
1329   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1330 
1331   // A's diag * (B's diag + B's off-diag)
1332   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1333   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1334 
1335   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1336 
1337   // A's off-diag * (F's diag + F's off-diag)
1338   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1339   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1340 
1341   // C = (Cd, Co) = (C1+C3, C2+C4)
1342   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1343   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1344   PetscFunctionReturn(PETSC_SUCCESS);
1345 }
1346 
1347 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1348 {
1349   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1350   Mat_Product                 *product;
1351   MatProductData_MPIAIJKokkos *pdata;
1352   MatProductType               ptype;
1353   Mat                          A, B;
1354 
1355   PetscFunctionBegin;
1356   MatCheckProduct(C, 1); // make sure C is a product
1357   product = C->product;
1358   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1359   ptype   = product->type;
1360   A       = product->A;
1361   B       = product->B;
1362 
1363   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1364   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1365   // we still do numeric.
1366   if (pdata->reusesym) { // numeric reuses results from symbolic
1367     pdata->reusesym = PETSC_FALSE;
1368     PetscFunctionReturn(PETSC_SUCCESS);
1369   }
1370 
1371   if (ptype == MATPRODUCT_AB) {
1372     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1373   } else if (ptype == MATPRODUCT_AtB) {
1374     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1375   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1376     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1377     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1378   }
1379 
1380   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1381   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1382   PetscFunctionReturn(PETSC_SUCCESS);
1383 }
1384 
1385 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1386 {
1387   Mat                          A, B;
1388   Mat_Product                 *product;
1389   MatProductType               ptype;
1390   MatProductData_MPIAIJKokkos *pdata;
1391   MatMatStruct                *mm = NULL;
1392   PetscInt                     m, n, M, N;
1393   Mat                          Cd, Co;
1394   MPI_Comm                     comm;
1395 
1396   PetscFunctionBegin;
1397   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1398   MatCheckProduct(C, 1);
1399   product = C->product;
1400   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1401   ptype = product->type;
1402   A     = product->A;
1403   B     = product->B;
1404 
1405   switch (ptype) {
1406   case MATPRODUCT_AB:
1407     m = A->rmap->n;
1408     n = B->cmap->n;
1409     M = A->rmap->N;
1410     N = B->cmap->N;
1411     break;
1412   case MATPRODUCT_AtB:
1413     m = A->cmap->n;
1414     n = B->cmap->n;
1415     M = A->cmap->N;
1416     N = B->cmap->N;
1417     break;
1418   case MATPRODUCT_PtAP:
1419     m = B->cmap->n;
1420     n = B->cmap->n;
1421     M = B->cmap->N;
1422     N = B->cmap->N;
1423     break; /* BtAB */
1424   default:
1425     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1426   }
1427 
1428   PetscCall(MatSetSizes(C, m, n, M, N));
1429   PetscCall(PetscLayoutSetUp(C->rmap));
1430   PetscCall(PetscLayoutSetUp(C->cmap));
1431   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1432 
1433   pdata           = new MatProductData_MPIAIJKokkos();
1434   pdata->reusesym = product->api_user;
1435 
1436   if (ptype == MATPRODUCT_AB) {
1437     auto mmAB = new MatMatStruct_AB();
1438     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1439     mm = pdata->mmAB = mmAB;
1440   } else if (ptype == MATPRODUCT_AtB) {
1441     auto mmAtB = new MatMatStruct_AtB();
1442     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1443     mm = pdata->mmAtB = mmAtB;
1444   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1445     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1446 
1447     auto mmAB = new MatMatStruct_AB();
1448     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1449     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1450     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1451     pdata->mmAB = mmAB;
1452 
1453     m = A->rmap->n; // Z's layout
1454     n = B->cmap->n;
1455     M = A->rmap->N;
1456     N = B->cmap->N;
1457     PetscCall(MatCreate(comm, &Z));
1458     PetscCall(MatSetSizes(Z, m, n, M, N));
1459     PetscCall(PetscLayoutSetUp(Z->rmap));
1460     PetscCall(PetscLayoutSetUp(Z->cmap));
1461     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1462     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1463 
1464     auto mmAtB = new MatMatStruct_AtB();
1465     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1466 
1467     pdata->Z = Z; // give ownership to pdata
1468     mm = pdata->mmAtB = mmAtB;
1469   }
1470 
1471   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1472   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1473   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1474 
1475   C->product->data       = pdata;
1476   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1477   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1478   PetscFunctionReturn(PETSC_SUCCESS);
1479 }
1480 
1481 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1482 {
1483   Mat_Product *product = mat->product;
1484   PetscBool    match   = PETSC_FALSE;
1485   PetscBool    usecpu  = PETSC_FALSE;
1486 
1487   PetscFunctionBegin;
1488   MatCheckProduct(mat, 1);
1489   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1490   if (match) { /* we can always fallback to the CPU if requested */
1491     switch (product->type) {
1492     case MATPRODUCT_AB:
1493       if (product->api_user) {
1494         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1495         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1496         PetscOptionsEnd();
1497       } else {
1498         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1499         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1500         PetscOptionsEnd();
1501       }
1502       break;
1503     case MATPRODUCT_AtB:
1504       if (product->api_user) {
1505         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1506         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1507         PetscOptionsEnd();
1508       } else {
1509         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1510         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1511         PetscOptionsEnd();
1512       }
1513       break;
1514     case MATPRODUCT_PtAP:
1515       if (product->api_user) {
1516         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1517         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1518         PetscOptionsEnd();
1519       } else {
1520         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1521         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1522         PetscOptionsEnd();
1523       }
1524       break;
1525     default:
1526       break;
1527     }
1528     match = (PetscBool)!usecpu;
1529   }
1530   if (match) {
1531     switch (product->type) {
1532     case MATPRODUCT_AB:
1533     case MATPRODUCT_AtB:
1534     case MATPRODUCT_PtAP:
1535       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1536       break;
1537     default:
1538       break;
1539     }
1540   }
1541   /* fallback to MPIAIJ ops */
1542   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1543   PetscFunctionReturn(PETSC_SUCCESS);
1544 }
1545 
1546 // Mirror of MatCOOStruct_MPIAIJ on device
1547 struct MatCOOStruct_MPIAIJKokkos {
1548   PetscCount           n;
1549   PetscSF              sf;
1550   PetscCount           Annz, Bnnz;
1551   PetscCount           Annz2, Bnnz2;
1552   PetscCountKokkosView Ajmap1, Aperm1;
1553   PetscCountKokkosView Bjmap1, Bperm1;
1554   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1555   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1556   PetscCountKokkosView Cperm1;
1557   MatScalarKokkosView  sendbuf, recvbuf;
1558 
1559   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1560   {
1561     auto &exec = PetscGetKokkosExecutionSpace();
1562 
1563     n       = coo_h->n;
1564     sf      = coo_h->sf;
1565     Annz    = coo_h->Annz;
1566     Bnnz    = coo_h->Bnnz;
1567     Annz2   = coo_h->Annz2;
1568     Bnnz2   = coo_h->Bnnz2;
1569     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1570     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1571     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1572     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1573     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1574     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1575     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1576     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1577     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1578     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1579     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1580     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1581     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1582     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1583   }
1584 
1585   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1586 };
1587 
1588 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1589 {
1590   PetscFunctionBegin;
1591   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1592   PetscFunctionReturn(PETSC_SUCCESS);
1593 }
1594 
1595 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1596 {
1597   PetscContainer             container_h, container_d;
1598   MatCOOStruct_MPIAIJ       *coo_h;
1599   MatCOOStruct_MPIAIJKokkos *coo_d;
1600 
1601   PetscFunctionBegin;
1602   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1603   mat->preallocated = PETSC_TRUE;
1604   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1605   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1606   PetscCall(MatZeroEntries(mat));
1607 
1608   // Copy the COO struct to device
1609   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1610   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1611   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1612 
1613   // Put the COO struct in a container and then attach that to the matrix
1614   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1615   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1616   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1617   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1618   PetscCall(PetscContainerDestroy(&container_d));
1619   PetscFunctionReturn(PETSC_SUCCESS);
1620 }
1621 
1622 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1623 {
1624   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1625   Mat                            A = mpiaij->A, B = mpiaij->B;
1626   MatScalarKokkosView            Aa, Ba;
1627   MatScalarKokkosView            v1;
1628   PetscMemType                   memtype;
1629   PetscContainer                 container;
1630   MatCOOStruct_MPIAIJKokkos     *coo;
1631   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
1632 
1633   PetscFunctionBegin;
1634   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1635   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1636 
1637   const auto &n      = coo->n;
1638   const auto &Annz   = coo->Annz;
1639   const auto &Annz2  = coo->Annz2;
1640   const auto &Bnnz   = coo->Bnnz;
1641   const auto &Bnnz2  = coo->Bnnz2;
1642   const auto &vsend  = coo->sendbuf;
1643   const auto &v2     = coo->recvbuf;
1644   const auto &Ajmap1 = coo->Ajmap1;
1645   const auto &Ajmap2 = coo->Ajmap2;
1646   const auto &Aimap2 = coo->Aimap2;
1647   const auto &Bjmap1 = coo->Bjmap1;
1648   const auto &Bjmap2 = coo->Bjmap2;
1649   const auto &Bimap2 = coo->Bimap2;
1650   const auto &Aperm1 = coo->Aperm1;
1651   const auto &Aperm2 = coo->Aperm2;
1652   const auto &Bperm1 = coo->Bperm1;
1653   const auto &Bperm2 = coo->Bperm2;
1654   const auto &Cperm1 = coo->Cperm1;
1655 
1656   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1657   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1658     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1659   } else {
1660     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1661   }
1662 
1663   if (imode == INSERT_VALUES) {
1664     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1665     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1666   } else {
1667     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1668     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1669   }
1670 
1671   PetscCall(PetscLogGpuTimeBegin());
1672   /* Pack entries to be sent to remote */
1673   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1674 
1675   /* Send remote entries to their owner and overlap the communication with local computation */
1676   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1677   /* Add local entries to A and B in one kernel */
1678   Kokkos::parallel_for(
1679     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1680       PetscScalar sum = 0.0;
1681       if (i < Annz) {
1682         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1683         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1684       } else {
1685         i -= Annz;
1686         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1687         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1688       }
1689     });
1690   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1691 
1692   /* Add received remote entries to A and B in one kernel */
1693   Kokkos::parallel_for(
1694     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1695       if (i < Annz2) {
1696         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1697       } else {
1698         i -= Annz2;
1699         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1700       }
1701     });
1702   PetscCall(PetscLogGpuTimeEnd());
1703 
1704   if (imode == INSERT_VALUES) {
1705     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1706     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1707   } else {
1708     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1709     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1710   }
1711   PetscFunctionReturn(PETSC_SUCCESS);
1712 }
1713 
1714 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1715 {
1716   PetscFunctionBegin;
1717   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1718   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1719   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1720   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1721   PetscCall(MatDestroy_MPIAIJ(A));
1722   PetscFunctionReturn(PETSC_SUCCESS);
1723 }
1724 
1725 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1726 {
1727   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1728   PetscBool   congruent;
1729 
1730   PetscFunctionBegin;
1731   PetscCall(MatHasCongruentLayouts(A, &congruent));
1732   if (congruent) { // square matrix and the diagonals are solely in the diag block
1733     PetscCall(MatShift(mpiaij->A, a));
1734   } else { // too hard, use the general version
1735     PetscCall(MatShift_Basic(A, a));
1736   }
1737   PetscFunctionReturn(PETSC_SUCCESS);
1738 }
1739 
1740 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1741 {
1742   PetscFunctionBegin;
1743   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1744   B->ops->mult                  = MatMult_MPIAIJKokkos;
1745   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1746   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1747   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1748   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1749   B->ops->shift                 = MatShift_MPIAIJKokkos;
1750 
1751   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1752   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1753   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1754   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1755   PetscFunctionReturn(PETSC_SUCCESS);
1756 }
1757 
1758 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1759 {
1760   Mat         B;
1761   Mat_MPIAIJ *a;
1762 
1763   PetscFunctionBegin;
1764   if (reuse == MAT_INITIAL_MATRIX) {
1765     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1766   } else if (reuse == MAT_REUSE_MATRIX) {
1767     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1768   }
1769   B = *newmat;
1770 
1771   B->boundtocpu = PETSC_FALSE;
1772   PetscCall(PetscFree(B->defaultvectype));
1773   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1774   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1775 
1776   a = static_cast<Mat_MPIAIJ *>(A->data);
1777   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1778   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1779   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1780   PetscCall(MatSetOps_MPIAIJKokkos(B));
1781   PetscFunctionReturn(PETSC_SUCCESS);
1782 }
1783 
1784 /*MC
1785    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1786 
1787    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1788 
1789    Options Database Key:
1790 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1791 
1792   Level: beginner
1793 
1794 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1795 M*/
1796 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1797 {
1798   PetscFunctionBegin;
1799   PetscCall(PetscKokkosInitializeCheck());
1800   PetscCall(MatCreate_MPIAIJ(A));
1801   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1802   PetscFunctionReturn(PETSC_SUCCESS);
1803 }
1804 
1805 /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1809 
1810   Collective
1811 
1812   Input Parameters:
+ comm  - MPI communicator
1814 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1815            This value should be the same as the local size used in creating the
1816            y vector for the matrix-vector product y = Ax.
1817 . n     - This value should be the same as the local size used in creating the
1818        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1819        calculated if N is given) For square matrices n is almost always `m`.
1820 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1821 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1822 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1823            (same value is used for all local rows)
1824 . d_nnz - array containing the number of nonzeros in the various rows of the
1825            DIAGONAL portion of the local submatrix (possibly different for each row)
1826            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1827            The size of this array is equal to the number of local rows, i.e `m`.
1828            For matrices you plan to factor you must leave room for the diagonal entry and
1829            put in the entry even if it is zero.
1830 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1831            submatrix (same value is used for all local rows).
1832 - o_nnz - array containing the number of nonzeros in the various rows of the
1833            OFF-DIAGONAL portion of the local submatrix (possibly different for
1834            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1835            structure. The size of this array is equal to the number
1836            of local rows, i.e `m`.
1837 
1838   Output Parameter:
1839 . A - the matrix
1840 
1841   Level: intermediate
1842 
1843   Notes:
1844   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1845   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1846   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1847 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
  storage.  That is, the stored row and column indices can begin at
  either one (as in Fortran) or zero.
1851 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1854 @*/
1855 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1856 {
1857   PetscMPIInt size;
1858 
1859   PetscFunctionBegin;
1860   PetscCall(MatCreate(comm, A));
1861   PetscCall(MatSetSizes(*A, m, n, M, N));
1862   PetscCallMPI(MPI_Comm_size(comm, &size));
1863   if (size > 1) {
1864     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1865     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1866   } else {
1867     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1868     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1869   }
1870   PetscFunctionReturn(PETSC_SUCCESS);
1871 }
1872