xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 2d6c3ceebcb340e82a066ce435cc2ceddc132986)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petscsf.h>
4 #include <petsc/private/sfimpl.h>
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7 #include <KokkosSparse_spadd.hpp>
8 #include <KokkosSparse_spgemm.hpp>
9 
10 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11 {
12   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
13 
14   PetscFunctionBegin;
15   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
16   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
17      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
18    */
19   if (mode == MAT_FINAL_ASSEMBLY) {
20     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
21     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
22     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
23   }
24 
25   PetscFunctionReturn(PETSC_SUCCESS);
26 }
27 
28 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
29 {
30   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
31 
32   PetscFunctionBegin;
33   PetscCall(PetscLayoutSetUp(mat->rmap));
34   PetscCall(PetscLayoutSetUp(mat->cmap));
35 #if defined(PETSC_USE_DEBUG)
36   if (d_nnz) {
37     PetscInt i;
38     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
39   }
40   if (o_nnz) {
41     PetscInt i;
42     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
43   }
44 #endif
45 #if defined(PETSC_USE_CTABLE)
46   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
47 #else
48   PetscCall(PetscFree(mpiaij->colmap));
49 #endif
50   PetscCall(PetscFree(mpiaij->garray));
51   PetscCall(VecDestroy(&mpiaij->lvec));
52   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
53   /* Because the B will have been resized we simply destroy it and create a new one each time */
54   PetscCall(MatDestroy(&mpiaij->B));
55 
56   if (!mpiaij->A) {
57     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
58     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
59   }
60   if (!mpiaij->B) {
61     PetscMPIInt size;
62     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
63     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
64     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
65   }
66   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
67   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
68   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
69   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
70   mat->preallocated = PETSC_TRUE;
71   PetscFunctionReturn(PETSC_SUCCESS);
72 }
73 
74 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
75 {
76   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
77   PetscInt    nt;
78 
79   PetscFunctionBegin;
80   PetscCall(VecGetLocalSize(xx, &nt));
81   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
82   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
83   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
84   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
85   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
86   PetscFunctionReturn(PETSC_SUCCESS);
87 }
88 
89 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
90 {
91   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
92   PetscInt    nt;
93 
94   PetscFunctionBegin;
95   PetscCall(VecGetLocalSize(xx, &nt));
96   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
97   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
98   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
99   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
100   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
101   PetscFunctionReturn(PETSC_SUCCESS);
102 }
103 
104 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
105 {
106   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
107   PetscInt    nt;
108 
109   PetscFunctionBegin;
110   PetscCall(VecGetLocalSize(xx, &nt));
111   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
112   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
113   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
114   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
115   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
116   PetscFunctionReturn(PETSC_SUCCESS);
117 }
118 
119 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
120    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
121    C still uses local column ids. Their corresponding global column ids are returned in glob.
122 */
123 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
124 {
125   Mat             Ad, Ao;
126   const PetscInt *cmap;
127 
128   PetscFunctionBegin;
129   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
130   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
131   if (glob) {
132     PetscInt cst, i, dn, on, *gidx;
133     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
134     PetscCall(MatGetLocalSize(Ao, NULL, &on));
135     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
136     PetscCall(PetscMalloc1(dn + on, &gidx));
137     for (i = 0; i < dn; i++) gidx[i] = cst + i;
138     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
139     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
140   }
141   PetscFunctionReturn(PETSC_SUCCESS);
142 }
143 
/* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */

/* State shared by the AB and AtB product plans; MatMatStruct_AB / MatMatStruct_AtB add plan-specific members. */
struct MatMatStruct {
  PetscInt            n, *garray;     // C's garray and its size.
  KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indices)
  KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
  KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
  PetscIntKokkosView  E_NzLeft;
  PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
  MatScalarKokkosView rootBuf, leafBuf;
  KokkosCsrMatrix     Fd, Fo; // F in split form

  KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
  KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
  KernelHandle kh3; // compute C3
  KernelHandle kh4; // compute C4

  PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
  PetscInt E_VectorLength;
  PetscInt E_RowsPerTeam;
  PetscInt F_TeamSize;
  PetscInt F_VectorLength;
  PetscInt F_RowsPerTeam;

  // Destroys sf. A destructor cannot return a PetscErrorCode, hence PetscCallAbort.
  // Note: garray is not freed here. NOTE(review): confirm its owner frees it elsewhere.
  ~MatMatStruct()
  {
    PetscFunctionBegin;
    PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
    PetscFunctionReturnVoid();
  }
};

/* Plan data specific to the C=AB product */
struct MatMatStruct_AB : public MatMatStruct {
  PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
  PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
  PetscIntKokkosView rowoffset;
};

/* Plan data specific to the C=A^tB product: jmap/jperm pairs describe how received nonzeros
   (in rootBuf) are summed into the nonzeros of Fd and Fo */
struct MatMatStruct_AtB : public MatMatStruct {
  // NOTE(review): these are filled from MatRowMapType host views in MatMPIAIJKokkosReduceBegin();
  // declaring them as MatColIdxKokkosView presumably relies on MatRowMapType == MatColIdxType — confirm.
  MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
  MatColIdxKokkosView Fdjperm;
  MatColIdxKokkosView Fojmap;
  MatColIdxKokkosView Fojperm;
};

/* Product data attached to a MatProduct on MPIAIJKOKKOS matrices; owns whichever plan was built */
struct MatProductData_MPIAIJKokkos {
  MatMatStruct_AB  *mmAB     = nullptr; // plan for C=AB (owned, deleted in the destructor)
  MatMatStruct_AtB *mmAtB    = nullptr; // plan for C=A^tB (owned, deleted in the destructor)
  PetscBool         reusesym = PETSC_FALSE; // presumably: symbolic data may be reused by the next numeric phase — TODO confirm
  Mat               Z        = nullptr; // store Z=AB in computing BtAB

  // Releases both plans and Z. Uses PetscCallAbort since a destructor cannot return an error code.
  ~MatProductData_MPIAIJKokkos()
  {
    delete mmAB;
    delete mmAtB;
    PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
  }
};
201 
202 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
203 {
204   PetscFunctionBegin;
205   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
206   PetscFunctionReturn(PETSC_SUCCESS);
207 }
208 
209 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
210    It is similar to MatCreateMPIAIJWithSplitArrays.
211 
212   Input Parameters:
213 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
214 .  A     - the diag matrix using local col ids
215 -  B     - the offdiag matrix using global col ids
216 
217   Output Parameter:
218 .  mat   - the updated MATMPIAIJKOKKOS matrix
219 */
220 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
221 {
222   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
223   PetscInt    m, n, M, N, Am, An, Bm, Bn;
224 
225   PetscFunctionBegin;
226   PetscCall(MatGetSize(mat, &M, &N));
227   PetscCall(MatGetLocalSize(mat, &m, &n));
228   PetscCall(MatGetLocalSize(A, &Am, &An));
229   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
230 
231   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
232   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
233   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
234   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
235   mpiaij->A      = A;
236   mpiaij->B      = B;
237   mpiaij->garray = garray;
238 
239   mat->preallocated     = PETSC_TRUE;
240   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
241 
242   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
243   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
244   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
245     also gets mpiaij->B compacted, with its col ids and size reduced
246   */
247   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
248   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
249   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
250   PetscFunctionReturn(PETSC_SUCCESS);
251 }
252 
253 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
254 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
255 template <class ExecutionSpace>
256 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
257 {
258   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
259 
260   PetscFunctionBegin;
261   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
262 
263   if (nnz_per_row < 1) nnz_per_row = 1;
264 
265   int max_vector_length = teamPolicy.vector_length_max();
266 
267   if (vector_length < 1) {
268     vector_length = 1;
269     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
270   }
271 
272   // Determine rows per thread
273   if (rows_per_thread < 1) {
274     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
275     else {
276       if (nnz_per_row < 20 && nnz > 5000000) {
277         rows_per_thread = 256;
278       } else rows_per_thread = 64;
279     }
280   }
281 
282   if (team_size < 1) {
283     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
284       team_size = 256 / vector_length;
285     } else {
286       team_size = 1;
287     }
288   }
289 
290   rows_per_team = rows_per_thread * team_size;
291 
292   if (rows_per_team < 0) {
293     PetscInt nnz_per_team = 4096;
294     PetscInt conc         = ExecutionSpace().concurrency();
295     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
296     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
297   }
298   PetscFunctionReturn(PETSC_SUCCESS);
299 }
300 
301 /*
302   Reduce two sets of global indices into local ones
303 
304   Input Parameters:
305 +  n1          - size of garray1[], the first set
306 .  garray1[n1] - a sorted global index array (without duplicates)
307 .  m           - size of indices[], the second set
308 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
309 
310   Output Parameters:
311 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
312 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
313 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
314 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
315 
316    Example, say
317     n1         = 5
318     garray1[5] = {1, 4, 7, 8, 10}
319     m          = 4
320     indices[4] = {2, 4, 8, 9}
321 
322    Combining them together, we have 7 global indices in garray2[]
323     n2         = 7
324     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
325 
326    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
327     map[5] = {0, 2, 3, 4, 6}
328 
329    On output, indices[] is updated with local indices
330     indices[4] = {1, 2, 4, 5}
331 */
332 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
333 {
334   PetscHMapI    g2l = nullptr;
335   PetscHashIter iter;
336   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
337   PetscInt      n2, *garray2;
338 
339   PetscFunctionBegin;
340   tot = 0;
341   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
342   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
343     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
344     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
345   }
346 
347   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
348     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
349     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
350   }
351 
352   // Pull out (unique) globals in the hash table and put them in garray2[]
353   n2 = tot;
354   PetscCall(PetscMalloc1(n2, &garray2));
355   tot = 0;
356   PetscHashIterBegin(g2l, iter);
357   while (!PetscHashIterAtEnd(g2l, iter)) {
358     PetscHashIterGetKey(g2l, iter, key);
359     PetscHashIterNext(g2l, iter);
360     garray2[tot++] = key;
361   }
362 
363   // Sort garray2[] and then map them to local indices starting from 0
364   PetscCall(PetscSortInt(n2, garray2));
365   PetscCall(PetscHMapIClear(g2l));
366   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
367 
368   // Rewrite indices[] with local indices
369   for (PetscInt i = 0; i < m; i++) {
370     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
371     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
372     indices[i] = val;
373   }
374   // Record the map that maps garray1[i] to garray2[map[i]]
375   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
376   PetscCall(PetscHMapIDestroy(&g2l));
377   *n2_      = n2;
378   *garray2_ = garray2;
379   PetscFunctionReturn(PETSC_SUCCESS);
380 }
381 
382 /*
383   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
384 
385   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
386 
387   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
388   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
389 
390   Input Parameters:
391 +  comm       - MPI communicator of E
392 .  A          - diag block of E, using local column indices
393 .  B          - off-diag block of E, using local column indices
394 .  cstart      - (global) start column of Ed
395 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
396 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
397 .  ownerSF     - the SF specifies ownership (root) of rows in E
398 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
399 -  mm          - to stash intermediate data structures for reuse
400 
401   Output Parameters:
402 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
403 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
404 
405   Notes:
406   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
407 
408  */
409 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
410 {
411   PetscFunctionBegin;
412   if (reuse == MAT_INITIAL_MATRIX) {
413     PetscInt Em = A.numRows(), Fm;
414     PetscInt n1 = B.numCols();
415 
416     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
417 
418     // Do the analysis on host
419     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
420     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
421     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
422     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
423     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
424     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
425 
426     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
427     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
428     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
429     for (PetscInt i = 0; i < Em; i++) {
430       const PetscInt *first, *last, *it;
431       PetscInt        count, step;
432       // std::lower_bound(first,last,cstart), but need to use global column indices
433       first = Bj + Bi[i];
434       last  = Bj + Bi[i + 1];
435       count = last - first;
436       while (count > 0) {
437         it   = first;
438         step = count / 2;
439         it += step;
440         if (garray1[*it] < cstart) { // map local to global
441           first = ++it;
442           count -= step + 1;
443         } else count = step;
444       }
445       E_NzLeft[i] = first - (Bj + Bi[i]);
446       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
447     }
448 
449     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
450     const PetscMPIInt *iranks, *ranks;
451     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
452     PetscInt           niranks, nranks;
453     MPI_Request       *reqs;
454     PetscMPIInt        tag;
455     PetscSF            reduceSF;
456     PetscInt          *sdisp, *rdisp;
457 
458     PetscCall(PetscCommGetNewTag(comm, &tag));
459     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
460     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
461 
462     // Find out length of each row I will receive. Even for the same row index, when they are from
463     // different senders, they might have different lengths (and sparsity patterns)
464     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
465     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
466 
467     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
468 
469     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
470     recvRowLen[0] = 0; // since we will make it in CSR format later
471     recvRowLen++;      // advance the pointer now
472     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
473     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
474     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
475 
476     // Build the real PetscSF for reducing E rows (buffer to buffer)
477     rdisp[0] = 0;
478     for (PetscInt i = 0; i < niranks; i++) {
479       rdisp[i + 1] = rdisp[i];
480       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
481     }
482     recvRowLen--; // put it back into csr format
483     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
484 
485     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
486     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
487     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
488 
489     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
490     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
491     PetscSFNode *iremote;
492 
493     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
494     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
495     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
496 
497     for (PetscInt i = 0; i < nranks; i++) {
498       PetscInt count = 0;
499       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
500       for (PetscInt j = 0; j < count; j++) {
501         iremote[nleaves + j].rank  = ranks[i];
502         iremote[nleaves + j].index = sdisp[i] + j;
503       }
504       nleaves += count;
505     }
506     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
507 
508     PetscCall(PetscSFCreate(comm, &reduceSF));
509     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
510 
511     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
512     PetscInt *sendCol, *recvCol;
513     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
514     for (PetscInt k = 0; k < roffset[nranks]; k++) {
515       PetscInt  i      = rmine[k]; // row to be copied
516       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
517       PetscInt  nzLeft = E_NzLeft[i];
518       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
519       for (PetscInt j = 0; j < alen + blen; j++) {
520         if (j < nzLeft) {
521           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
522         } else if (j < nzLeft + alen) {
523           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
524         } else {
525           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
526         }
527       }
528     }
529     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
530     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
531 
532     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
533     PetscInt *recvRowPerm, *recvColSorted;
534     PetscInt *recvNzPerm, *recvNzPermSorted;
535     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
536 
537     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
538     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
539     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
540 
541     // i[] array, nz are always easiest to compute
542     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
543     MatRowMapType          *Fdi, *Foi;
544     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
545     PetscInt                iter;
546 
547     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
548     Kokkos::deep_copy(Foi_h, 0);
549     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
550     Foi  = Foi_h.data() + 1;
551     iter = 0;
552     while (iter < recvRowCnt) { // iter over received rows
553       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
554       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
555 
556       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
557 
558       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
559       PetscInt  nz    = 0; // nz (with dups) in the current row
560       PetscInt *jbuf  = recvColSorted + FnzDups;
561       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
562       PetscInt *jbuf2 = jbuf; // temp pointers
563       PetscInt *pbuf2 = pbuf;
564       for (PetscInt d = 0; d < dupRows; d++) {
565         PetscInt i   = recvRowPerm[iter + d];
566         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
567         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
568         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
569         jbuf2 += len;
570         pbuf2 += len;
571         nz += len;
572       }
573       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
574 
575       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
576       PetscInt cur = 0;
577       while (cur < nz) {
578         PetscInt curColIdx = jbuf[cur];
579         PetscInt dups      = 1;
580 
581         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
582         if (curColIdx >= cstart && curColIdx < cend) {
583           Fdi[curRowIdx]++;
584           FdnzDups += dups;
585         } else {
586           Foi[curRowIdx]++;
587           FonzDups += dups;
588         }
589         cur += dups;
590       }
591 
592       FnzDups += nz;
593       iter += dupRows; // Move to next unique row
594     }
595 
596     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
597     Foi = Foi_h.data();
598     for (PetscInt i = 0; i < Fm; i++) {
599       Fdi[i + 1] += Fdi[i];
600       Foi[i + 1] += Foi[i];
601     }
602     Fdnz = Fdi[Fm];
603     Fonz = Foi[Fm];
604     PetscCall(PetscFree2(sendCol, recvCol));
605 
606     // Allocate j, jmap, jperm for Fd and Fo
607     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
608     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
609     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
610     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
611     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
612     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
613 
614     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
615     Fdjmap[0] = 0;
616     Fojmap[0] = 0;
617     FnzDups   = 0;
618     Fdnz      = 0;
619     Fonz      = 0;
620     iter      = 0; // iter over received rows
621     while (iter < recvRowCnt) {
622       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
623       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
624       PetscInt nz        = 0;                           // nz (with dups) in the current row
625 
626       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
627       for (PetscInt d = 0; d < dupRows; d++) {
628         PetscInt i = recvRowPerm[iter + d];
629         nz += recvRowLen[i + 1] - recvRowLen[i];
630       }
631 
632       PetscInt *jbuf = recvColSorted + FnzDups;
633       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
634       PetscInt cur = 0;
635       while (cur < nz) {
636         PetscInt curColIdx = jbuf[cur];
637         PetscInt dups      = 1;
638 
639         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
640         if (curColIdx >= cstart && curColIdx < cend) {
641           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
642           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
643           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
644           FdnzDups += dups;
645           Fdnz++;
646         } else {
647           Foj[Fonz]        = curColIdx; // in global
648           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
649           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
650           FonzDups += dups;
651           Fonz++;
652         }
653         cur += dups;
654         FnzDups += dups;
655       }
656       iter += dupRows; // Move to next unique row
657     }
658     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
659     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
660 
661     // Combine global column indices in garray1[] and Foj[]
662     PetscInt n2, *garray2;
663 
664     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
665     mm->sf       = reduceSF;
666     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
667     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
668     mm->garray   = garray2; // give ownership, so no free
669     mm->n        = n2;
670     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
671     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
672     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
673     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
674     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
675 
676     // Output Fd and Fo in KokkosCsrMatrix format
677     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
678     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
679     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
680     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
681     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
682     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
683 
684     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
685     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
686 
687     // Compute kernel launch parameters in merging E
688     PetscInt teamSize, vectorLength, rowsPerTeam;
689 
690     teamSize = vectorLength = rowsPerTeam = -1;
691     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
692     mm->E_TeamSize     = teamSize;
693     mm->E_VectorLength = vectorLength;
694     mm->E_RowsPerTeam  = rowsPerTeam;
695   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
696 
697   // Handy aliases
698   auto       &Aa           = A.values;
699   auto       &Ba           = B.values;
700   const auto &Ai           = A.graph.row_map;
701   const auto &Bi           = B.graph.row_map;
702   const auto &E_NzLeft     = mm->E_NzLeft;
703   auto       &leafBuf      = mm->leafBuf;
704   auto       &rootBuf      = mm->rootBuf;
705   PetscSF     reduceSF     = mm->sf;
706   PetscInt    Em           = A.numRows();
707   PetscInt    teamSize     = mm->E_TeamSize;
708   PetscInt    vectorLength = mm->E_VectorLength;
709   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
710   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
711 
712   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
713   PetscCallCXX(Kokkos::parallel_for(
714     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
715       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
716         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
717         if (i < Em) {
718           PetscInt disp   = Ai(i) + Bi(i);
719           PetscInt alen   = Ai(i + 1) - Ai(i);
720           PetscInt blen   = Bi(i + 1) - Bi(i);
721           PetscInt nzleft = E_NzLeft(i);
722 
723           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
724             MatScalar &val = leafBuf(disp + j);
725             if (j < nzleft) { // B left
726               val = Ba(Bi(i) + j);
727             } else if (j < nzleft + alen) { // diag A
728               val = Aa(Ai(i) + j - nzleft);
729             } else { // B right
730               val = Ba(Bi(i) + j - alen);
731             }
732           });
733         }
734       });
735     }));
736   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
737   PetscFunctionReturn(PETSC_SUCCESS);
738 }
739 
740 // To finish MatMPIAIJKokkosReduce.
741 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
742 {
743   PetscFunctionBegin;
744   auto       &leafBuf  = mm->leafBuf;
745   auto       &rootBuf  = mm->rootBuf;
746   auto       &Fda      = mm->Fd.values;
747   const auto &Fdjmap   = mm->Fdjmap;
748   const auto &Fdjperm  = mm->Fdjperm;
749   auto        Fdnz     = mm->Fd.nnz();
750   auto       &Foa      = mm->Fo.values;
751   const auto &Fojmap   = mm->Fojmap;
752   const auto &Fojperm  = mm->Fojperm;
753   auto        Fonz     = mm->Fo.nnz();
754   PetscSF     reduceSF = mm->sf;
755 
756   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
757 
758   // Reduce data in rootBuf to Fd and Fo
759   PetscCallCXX(Kokkos::parallel_for(
760     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
761       PetscScalar sum = 0.0;
762       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
763       Fda(i) = sum;
764     }));
765 
766   PetscCallCXX(Kokkos::parallel_for(
767     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
768       PetscScalar sum = 0.0;
769       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
770       Foa(i) = sum;
771     }));
772   PetscFunctionReturn(PETSC_SUCCESS);
773 }
774 
775 /*
776   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
777 
778   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
779   device and involves various index mapping.
780 
781   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
782   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
783   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
784   F has the same column layout as E.
785 
786   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
787   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
788   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
789   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
790   column indices in Fo and update Fo with local indices.
791 
792    Input Parameters:
793 +   E       - the MPIAIJKOKKOS matrix
794 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
795 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
796 -   mm      - to stash matproduct intermediate data structures
797 
798     Output Parameters:
799 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
800 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
801 
802     Notes:
803     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
804     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
805 */
806 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
807 {
808   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
809   Mat               A = empi->A, B = empi->B; // diag and off-diag
810   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
811   PetscInt          Em = E->rmap->n; // #local rows
812   MPI_Comm          comm;
813 
814   PetscFunctionBegin;
815   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
816   if (reuse == MAT_INITIAL_MATRIX) {
817     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
818     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
819     const PetscInt *garray1 = empi->garray; // its size is n1
820     PetscInt        cstart, cend;
821     PetscSF         bcastSF;
822 
823     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
824 
825     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
826     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
827     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
828     for (PetscInt i = 0; i < Em; i++) {
829       const PetscInt *first, *last, *it;
830       PetscInt        count, step;
831       // std::lower_bound(first,last,cstart), but need to use global column indices
832       first = Bj + Bi[i];
833       last  = Bj + Bi[i + 1];
834       count = last - first;
835       while (count > 0) {
836         it   = first;
837         step = count / 2;
838         it += step;
839         if (empi->garray[*it] < cstart) { // map local to global
840           first = ++it;
841           count -= step + 1;
842         } else count = step;
843       }
844       E_NzLeft[i] = first - (Bj + Bi[i]);
845       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
846     }
847 
848     // Compute row pointer Fi of F
849     PetscInt *Fi, Fm, Fnz;
850     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
851     PetscCall(PetscMalloc1(Fm + 1, &Fi));
852     Fi[0] = 0;
853     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
854     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
855     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
856     Fnz = Fi[Fm];
857 
858     // Build the real PetscSF for bcasting E rows (buffer to buffer)
859     const PetscMPIInt *iranks, *ranks;
860     const PetscInt    *ioffset, *irootloc, *roffset;
861     PetscInt           niranks, nranks, *sdisp, *rdisp;
862     MPI_Request       *reqs;
863     PetscMPIInt        tag;
864 
865     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
866     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
867     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
868 
869     sdisp[0] = 0; // send displacement
870     for (PetscInt i = 0; i < niranks; i++) {
871       sdisp[i + 1] = sdisp[i];
872       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
873         PetscInt r = irootloc[j]; // row to be sent
874         sdisp[i + 1] += E_RowLen[r];
875       }
876     }
877 
878     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
879     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
880     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
881     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
882 
883     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
884     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
885     PetscSFNode *iremote;                  // give ownership to bcastSF
886     PetscCall(PetscMalloc1(nleaves, &iremote));
887     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
888       PetscInt k = 0;
889       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
890         iremote[j].rank  = ranks[i];
891         iremote[j].index = rdisp[i] + k; // their root location
892         k++;
893       }
894     }
895     PetscCall(PetscSFCreate(comm, &bcastSF));
896     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
897     PetscCall(PetscFree3(sdisp, rdisp, reqs));
898 
899     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
900     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
901     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
902     rowoffset[0]                     = 0;
903     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
904 
905     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
906     PetscInt *jbuf, *Fj;
907     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
908     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
909       PetscInt  i      = irootloc[k]; // row to be copied
910       PetscInt *buf    = &jbuf[rowoffset[k]];
911       PetscInt  nzLeft = E_NzLeft[i];
912       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
913       for (PetscInt j = 0; j < alen + blen; j++) {
914         if (j < nzLeft) {
915           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
916         } else if (j < nzLeft + alen) {
917           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
918         } else {
919           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
920         }
921       }
922     }
923     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
924     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
925 
926     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
927     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
928     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
929     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
930     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
931 
932     Fdi[0] = Foi[0] = 0;
933     for (PetscInt i = 0; i < Fm; i++) {
934       PetscInt *first, *last, *lb1, *lb2;
935       // cut the row into: Left, [cstart, cend), Right
936       first       = Fj + Fi[i];
937       last        = Fj + Fi[i + 1];
938       lb1         = std::lower_bound(first, last, cstart);
939       F_NzLeft[i] = lb1 - first;
940       lb2         = std::lower_bound(first, last, cend);
941       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
942       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
943     }
944     for (PetscInt i = 0; i < Fm; i++) {
945       Fdi[i + 1] += Fdi[i];
946       Foi[i + 1] += Foi[i];
947     }
948 
949     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
950     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
951     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
952     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
953 
954     for (PetscInt i = 0; i < Fm; i++) {
955       PetscInt nzLeft = F_NzLeft[i];
956       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
957       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
958         gid = Fj[Fi[i] + j];
959         if (j < nzLeft) { // left, in global
960           Foj[Foi[i] + j] = gid;
961         } else if (j < nzLeft + len) { // diag, in local
962           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
963         } else { // right, in global
964           Foj[Foi[i] + j - len] = gid;
965         }
966       }
967     }
968     PetscCall(PetscFree2(jbuf, Fj));
969     PetscCall(PetscFree(Fi));
970 
971     // Reduce global indices in Foj[] and garray1[] into local ones
972     PetscInt n2, *garray2;
973     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
974 
975     // Record the plans built above, for reuse
976     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
977     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
978     Kokkos::deep_copy(irootloc_h, tmp);
979     mm->sf        = bcastSF;
980     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
981     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
982     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
983     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
984     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
985     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
986     mm->garray    = garray2;
987     mm->n         = n2;
988 
989     // Output Fd and Fo in KokkosCsrMatrix format
990     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
991     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
992     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
993     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
994     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
995 
996     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
997     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
998 
999     // Compute kernel launch parameters in merging E or splitting F
1000     PetscInt teamSize, vectorLength, rowsPerTeam;
1001 
1002     teamSize = vectorLength = rowsPerTeam = -1;
1003     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1004     mm->E_TeamSize     = teamSize;
1005     mm->E_VectorLength = vectorLength;
1006     mm->E_RowsPerTeam  = rowsPerTeam;
1007 
1008     teamSize = vectorLength = rowsPerTeam = -1;
1009     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1010     mm->F_TeamSize     = teamSize;
1011     mm->F_VectorLength = vectorLength;
1012     mm->F_RowsPerTeam  = rowsPerTeam;
1013   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1014 
1015   // Sync E's value to device
1016   akok->a_dual.sync_device();
1017   bkok->a_dual.sync_device();
1018 
1019   // Handy aliases
1020   const auto &Aa = akok->a_dual.view_device();
1021   const auto &Ba = bkok->a_dual.view_device();
1022   const auto &Ai = akok->i_dual.view_device();
1023   const auto &Bi = bkok->i_dual.view_device();
1024 
1025   // Fetch the plans
1026   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1027   PetscSF             &bcastSF   = mm->sf;
1028   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1029   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1030   PetscIntKokkosView  &irootloc  = mm->irootloc;
1031   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1032 
1033   PetscInt teamSize     = mm->E_TeamSize;
1034   PetscInt vectorLength = mm->E_VectorLength;
1035   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1036   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1037 
1038   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1039   PetscCallCXX(Kokkos::parallel_for(
1040     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1041       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1042         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1043         if (r < irootloc.extent(0)) {
1044           PetscInt i      = irootloc(r); // row i of E
1045           PetscInt disp   = rowoffset(r);
1046           PetscInt alen   = Ai(i + 1) - Ai(i);
1047           PetscInt blen   = Bi(i + 1) - Bi(i);
1048           PetscInt nzleft = E_NzLeft(i);
1049 
1050           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1051             if (j < nzleft) { // B left
1052               rootBuf(disp + j) = Ba(Bi(i) + j);
1053             } else if (j < nzleft + alen) { // diag A
1054               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1055             } else { // B right
1056               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1057             }
1058           });
1059         }
1060       });
1061     }));
1062   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1063   PetscFunctionReturn(PETSC_SUCCESS);
1064 }
1065 
1066 // To finish MatMPIAIJKokkosBcast.
1067 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1068 {
1069   PetscFunctionBegin;
1070   const auto &Fd  = mm->Fd;
1071   const auto &Fo  = mm->Fo;
1072   const auto &Fdi = Fd.graph.row_map;
1073   const auto &Foi = Fo.graph.row_map;
1074   auto       &Fda = Fd.values;
1075   auto       &Foa = Fo.values;
1076   auto        Fm  = Fd.numRows();
1077 
1078   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1079   PetscSF             &bcastSF      = mm->sf;
1080   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1081   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1082   PetscInt             teamSize     = mm->F_TeamSize;
1083   PetscInt             vectorLength = mm->F_VectorLength;
1084   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1085   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1086 
1087   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1088 
1089   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1090   PetscCallCXX(Kokkos::parallel_for(
1091     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1092       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1093         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1094         if (i < Fm) {
1095           PetscInt nzLeft = F_NzLeft(i);
1096           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1097           PetscInt blen   = Foi(i + 1) - Foi(i);
1098           PetscInt Fii    = Fdi(i) + Foi(i);
1099 
1100           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1101             PetscScalar val = leafBuf(Fii + j);
1102             if (j < nzLeft) { // left
1103               Foa(Foi(i) + j) = val;
1104             } else if (j < nzLeft + alen) { // diag
1105               Fda(Fdi(i) + j - nzLeft) = val;
1106             } else { // right
1107               Foa(Foi(i) + j - alen) = val;
1108             }
1109           });
1110         }
1111       });
1112     }));
1113   PetscFunctionReturn(PETSC_SUCCESS);
1114 }
1115 
1116 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1117 {
1118   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1119   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1120   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1121   PetscInt        cstart, cend;
1122   MPI_Comm        comm;
1123 
1124   PetscFunctionBegin;
1125   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1126   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1127   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1128   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1129   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1130   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1131   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1132 
1133   // TODO: add command line options to select spgemm algorithms
1134   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1135 
1136   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1137 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1138   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1139   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1140   #endif
1141 #endif
1142 
1143   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1144   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1145   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1146   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1147 
1148   // Aot * (B's diag + B's off-diag)
1149   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1150   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1151   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1152   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1153   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1154   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1155 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1156   PetscCallCXX(sort_crs_matrix(mm->C3));
1157   PetscCallCXX(sort_crs_matrix(mm->C4));
1158 #endif
1159 
1160   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1161   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1162   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1163   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1164 
1165   // Adt * (B's diag + B's off-diag)
1166   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1167   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1168   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1169   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1170 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1171   PetscCallCXX(sort_crs_matrix(mm->C1));
1172   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1173 #endif
1174 
1175   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1176 
1177   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1178   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1179   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1180   PetscCallCXX(Kokkos::parallel_for(
1181     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1182   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1183 
1184   // C = (C1+Fd, C2+Fo)
1185   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1186   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1187   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1188   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1189   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1190   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1191   PetscFunctionReturn(PETSC_SUCCESS);
1192 }
1193 
1194 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1195 {
1196   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1197   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1198   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1199   MPI_Comm        comm;
1200 
1201   PetscFunctionBegin;
1202   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1203   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1204   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1205   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1206   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1207 
1208   // Aot * (B's diag + B's off-diag)
1209   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1210   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1211 
1212   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1213   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1214 
1215   // Adt * (B's diag + B's off-diag)
1216   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1217   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1218 
1219   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1220 
1221   // C = (C1+Fd, C2+Fo)
1222   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1223   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1224   PetscFunctionReturn(PETSC_SUCCESS);
1225 }
1226 
1227 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1228 
1229   Input Parameters:
1230 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1231 .  A        - an MPIAIJKOKKOS matrix
1232 .  B        - an MPIAIJKOKKOS matrix
1233 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1234 */
1235 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1236 {
1237   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1238   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1239   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1240 
1241   PetscFunctionBegin;
1242   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1243   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1244   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1245   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1246 
1247   // TODO: add command line options to select spgemm algorithms
1248   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1249 
1250   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1251 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1252   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1253   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1254   #endif
1255 #endif
1256 
1257   mm->kh1.create_spgemm_handle(spgemm_alg);
1258   mm->kh2.create_spgemm_handle(spgemm_alg);
1259   mm->kh3.create_spgemm_handle(spgemm_alg);
1260   mm->kh4.create_spgemm_handle(spgemm_alg);
1261 
1262   // Bcast B's rows to form F, and overlap the communication
1263   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1264   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1265 
1266   // A's diag * (B's diag + B's off-diag)
1267   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1268   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1269   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1270   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1271   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1272   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1273 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1274   PetscCallCXX(sort_crs_matrix(mm->C1));
1275   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1276 #endif
1277 
1278   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1279 
1280   // A's off-diag * (F's diag + F's off-diag)
1281   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1282   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1283   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1284   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1285 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1286   PetscCallCXX(sort_crs_matrix(mm->C3));
1287   PetscCallCXX(sort_crs_matrix(mm->C4));
1288 #endif
1289 
1290   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1291   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1292   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1293   PetscCallCXX(Kokkos::parallel_for(
1294     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1295   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1296 
1297   // C = (Cd, Co) = (C1+C3, C2+C4)
1298   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1299   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1300   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1301   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1302   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1303   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1304   PetscFunctionReturn(PETSC_SUCCESS);
1305 }
1306 
1307 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1308 {
1309   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1310   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1311   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1312 
1313   PetscFunctionBegin;
1314   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1315   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1316   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1317   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1318 
1319   // Bcast B's rows to form F, and overlap the communication
1320   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1321 
1322   // A's diag * (B's diag + B's off-diag)
1323   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1324   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1325 
1326   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1327 
1328   // A's off-diag * (F's diag + F's off-diag)
1329   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1330   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1331 
1332   // C = (Cd, Co) = (C1+C3, C2+C4)
1333   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1334   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1335   PetscFunctionReturn(PETSC_SUCCESS);
1336 }
1337 
1338 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1339 {
1340   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1341   Mat_Product                 *product;
1342   MatProductData_MPIAIJKokkos *pdata;
1343   MatProductType               ptype;
1344   Mat                          A, B;
1345 
1346   PetscFunctionBegin;
1347   MatCheckProduct(C, 1); // make sure C is a product
1348   product = C->product;
1349   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1350   ptype   = product->type;
1351   A       = product->A;
1352   B       = product->B;
1353 
1354   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1355   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1356   // we still do numeric.
1357   if (pdata->reusesym) { // numeric reuses results from symbolic
1358     pdata->reusesym = PETSC_FALSE;
1359     PetscFunctionReturn(PETSC_SUCCESS);
1360   }
1361 
1362   if (ptype == MATPRODUCT_AB) {
1363     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1364   } else if (ptype == MATPRODUCT_AtB) {
1365     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1366   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1367     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1368     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1369   }
1370 
1371   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1372   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1373   PetscFunctionReturn(PETSC_SUCCESS);
1374 }
1375 
1376 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1377 {
1378   Mat                          A, B;
1379   Mat_Product                 *product;
1380   MatProductType               ptype;
1381   MatProductData_MPIAIJKokkos *pdata;
1382   MatMatStruct                *mm = NULL;
1383   PetscInt                     m, n, M, N;
1384   Mat                          Cd, Co;
1385   MPI_Comm                     comm;
1386 
1387   PetscFunctionBegin;
1388   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1389   MatCheckProduct(C, 1);
1390   product = C->product;
1391   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1392   ptype = product->type;
1393   A     = product->A;
1394   B     = product->B;
1395 
1396   switch (ptype) {
1397   case MATPRODUCT_AB:
1398     m = A->rmap->n;
1399     n = B->cmap->n;
1400     M = A->rmap->N;
1401     N = B->cmap->N;
1402     break;
1403   case MATPRODUCT_AtB:
1404     m = A->cmap->n;
1405     n = B->cmap->n;
1406     M = A->cmap->N;
1407     N = B->cmap->N;
1408     break;
1409   case MATPRODUCT_PtAP:
1410     m = B->cmap->n;
1411     n = B->cmap->n;
1412     M = B->cmap->N;
1413     N = B->cmap->N;
1414     break; /* BtAB */
1415   default:
1416     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1417   }
1418 
1419   PetscCall(MatSetSizes(C, m, n, M, N));
1420   PetscCall(PetscLayoutSetUp(C->rmap));
1421   PetscCall(PetscLayoutSetUp(C->cmap));
1422   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1423 
1424   pdata           = new MatProductData_MPIAIJKokkos();
1425   pdata->reusesym = product->api_user;
1426 
1427   if (ptype == MATPRODUCT_AB) {
1428     auto mmAB = new MatMatStruct_AB();
1429     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1430     mm = pdata->mmAB = mmAB;
1431   } else if (ptype == MATPRODUCT_AtB) {
1432     auto mmAtB = new MatMatStruct_AtB();
1433     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1434     mm = pdata->mmAtB = mmAtB;
1435   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1436     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1437 
1438     auto mmAB = new MatMatStruct_AB();
1439     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1440     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1441     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1442     pdata->mmAB = mmAB;
1443 
1444     m = A->rmap->n; // Z's layout
1445     n = B->cmap->n;
1446     M = A->rmap->N;
1447     N = B->cmap->N;
1448     PetscCall(MatCreate(comm, &Z));
1449     PetscCall(MatSetSizes(Z, m, n, M, N));
1450     PetscCall(PetscLayoutSetUp(Z->rmap));
1451     PetscCall(PetscLayoutSetUp(Z->cmap));
1452     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1453     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1454 
1455     auto mmAtB = new MatMatStruct_AtB();
1456     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1457 
1458     pdata->Z = Z; // give ownership to pdata
1459     mm = pdata->mmAtB = mmAtB;
1460   }
1461 
1462   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1463   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1464   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1465 
1466   C->product->data       = pdata;
1467   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1468   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1469   PetscFunctionReturn(PETSC_SUCCESS);
1470 }
1471 
1472 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1473 {
1474   Mat_Product *product = mat->product;
1475   PetscBool    match   = PETSC_FALSE;
1476   PetscBool    usecpu  = PETSC_FALSE;
1477 
1478   PetscFunctionBegin;
1479   MatCheckProduct(mat, 1);
1480   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1481   if (match) { /* we can always fallback to the CPU if requested */
1482     switch (product->type) {
1483     case MATPRODUCT_AB:
1484       if (product->api_user) {
1485         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1486         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1487         PetscOptionsEnd();
1488       } else {
1489         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1490         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1491         PetscOptionsEnd();
1492       }
1493       break;
1494     case MATPRODUCT_AtB:
1495       if (product->api_user) {
1496         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1497         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1498         PetscOptionsEnd();
1499       } else {
1500         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1501         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1502         PetscOptionsEnd();
1503       }
1504       break;
1505     case MATPRODUCT_PtAP:
1506       if (product->api_user) {
1507         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1508         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1509         PetscOptionsEnd();
1510       } else {
1511         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1512         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1513         PetscOptionsEnd();
1514       }
1515       break;
1516     default:
1517       break;
1518     }
1519     match = (PetscBool)!usecpu;
1520   }
1521   if (match) {
1522     switch (product->type) {
1523     case MATPRODUCT_AB:
1524     case MATPRODUCT_AtB:
1525     case MATPRODUCT_PtAP:
1526       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1527       break;
1528     default:
1529       break;
1530     }
1531   }
1532   /* fallback to MPIAIJ ops */
1533   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1534   PetscFunctionReturn(PETSC_SUCCESS);
1535 }
1536 
1537 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1538 {
1539   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1540   Mat_MPIAIJKokkos *mpikok;
1541 
1542   PetscFunctionBegin;
1543   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1544   mat->preallocated = PETSC_TRUE;
1545   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1546   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1547   PetscCall(MatZeroEntries(mat));
1548   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1549   delete mpikok;
1550   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1551   PetscFunctionReturn(PETSC_SUCCESS);
1552 }
1553 
1554 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1555 {
1556   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1557   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1558   Mat                         A = mpiaij->A, B = mpiaij->B;
1559   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1560   MatScalarKokkosView         Aa, Ba;
1561   MatScalarKokkosView         v1;
1562   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1563   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1564   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1565   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1566   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1567   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1568   PetscMemType                memtype;
1569 
1570   PetscFunctionBegin;
1571   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1572   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1573     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1574   } else {
1575     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1576   }
1577 
1578   if (imode == INSERT_VALUES) {
1579     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1580     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1581   } else {
1582     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1583     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1584   }
1585 
1586   /* Pack entries to be sent to remote */
1587   Kokkos::parallel_for(
1588     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1589 
1590   /* Send remote entries to their owner and overlap the communication with local computation */
1591   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1592   /* Add local entries to A and B in one kernel */
1593   Kokkos::parallel_for(
1594     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1595       PetscScalar sum = 0.0;
1596       if (i < Annz) {
1597         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1598         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1599       } else {
1600         i -= Annz;
1601         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1602         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1603       }
1604     });
1605   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1606 
1607   /* Add received remote entries to A and B in one kernel */
1608   Kokkos::parallel_for(
1609     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1610       if (i < Annz2) {
1611         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1612       } else {
1613         i -= Annz2;
1614         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1615       }
1616     });
1617 
1618   if (imode == INSERT_VALUES) {
1619     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1620     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1621   } else {
1622     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1623     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1624   }
1625   PetscFunctionReturn(PETSC_SUCCESS);
1626 }
1627 
1628 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1629 {
1630   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1631 
1632   PetscFunctionBegin;
1633   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1634   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1635   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1636   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1637   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1638   PetscCall(MatDestroy_MPIAIJ(A));
1639   PetscFunctionReturn(PETSC_SUCCESS);
1640 }
1641 
1642 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1643 {
1644   Mat         B;
1645   Mat_MPIAIJ *a;
1646 
1647   PetscFunctionBegin;
1648   if (reuse == MAT_INITIAL_MATRIX) {
1649     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1650   } else if (reuse == MAT_REUSE_MATRIX) {
1651     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1652   }
1653   B = *newmat;
1654 
1655   B->boundtocpu = PETSC_FALSE;
1656   PetscCall(PetscFree(B->defaultvectype));
1657   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1658   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1659 
1660   a = static_cast<Mat_MPIAIJ *>(A->data);
1661   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1662   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1663   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1664 
1665   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1666   B->ops->mult                  = MatMult_MPIAIJKokkos;
1667   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1668   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1669   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1670   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1671 
1672   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1673   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1674   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1675   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1676   PetscFunctionReturn(PETSC_SUCCESS);
1677 }
1678 /*MC
1679    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1680 
1681    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
1682 
1683    Options Database Key:
1684 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1685 
1686   Level: beginner
1687 
1688 .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1689 M*/
1690 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1691 {
1692   PetscFunctionBegin;
1693   PetscCall(PetscKokkosInitializeCheck());
1694   PetscCall(MatCreate_MPIAIJ(A));
1695   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1696   PetscFunctionReturn(PETSC_SUCCESS);
1697 }
1698 
1699 /*@C
1700    MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
1701    (the default parallel PETSc format).  This matrix will ultimately pushed down
1702    to Kokkos for calculations.
1703 
1704    Collective
1705 
1706    Input Parameters:
1707 +  comm - MPI communicator, set to `PETSC_COMM_SELF`
1708 .  m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1709            This value should be the same as the local size used in creating the
1710            y vector for the matrix-vector product y = Ax.
1711 .  n - This value should be the same as the local size used in creating the
1712        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1713        calculated if N is given) For square matrices n is almost always `m`.
1714 .  M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1715 .  N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1716 .  d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1717            (same value is used for all local rows)
1718 .  d_nnz - array containing the number of nonzeros in the various rows of the
1719            DIAGONAL portion of the local submatrix (possibly different for each row)
1720            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1721            The size of this array is equal to the number of local rows, i.e `m`.
1722            For matrices you plan to factor you must leave room for the diagonal entry and
1723            put in the entry even if it is zero.
1724 .  o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1725            submatrix (same value is used for all local rows).
1726 -  o_nnz - array containing the number of nonzeros in the various rows of the
1727            OFF-DIAGONAL portion of the local submatrix (possibly different for
1728            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1729            structure. The size of this array is equal to the number
1730            of local rows, i.e `m`.
1731 
1732    Output Parameter:
1733 .  A - the matrix
1734 
1735    Level: intermediate
1736 
1737    Notes:
1738    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1739    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1740    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1741 
1742    The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
1743    storage.  That is, the stored row and column indices can begin at
1744    either one (as in Fortran) or zero.
1745 
1746 .seealso: [](chapter_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1747           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
1748 @*/
1749 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1750 {
1751   PetscMPIInt size;
1752 
1753   PetscFunctionBegin;
1754   PetscCall(MatCreate(comm, A));
1755   PetscCall(MatSetSizes(*A, m, n, M, N));
1756   PetscCallMPI(MPI_Comm_size(comm, &size));
1757   if (size > 1) {
1758     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1759     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1760   } else {
1761     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1762     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1763   }
1764   PetscFunctionReturn(PETSC_SUCCESS);
1765 }
1766