xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision e2fbb1ba6e3eace0c4c0036fdbc8a86f123a4fe5)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <KokkosSparse_spadd.hpp>
7 #include <KokkosSparse_spgemm.hpp>
8 
9 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
10 {
11   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
12 
13   PetscFunctionBegin;
14   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
15   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
16      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
17    */
18   if (mode == MAT_FINAL_ASSEMBLY) {
19     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
20     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
21     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
22   }
23   PetscFunctionReturn(PETSC_SUCCESS);
24 }
25 
26 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
27 {
28   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
29 
30   PetscFunctionBegin;
31   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
32   if (mat->hash_active) {
33     mat->ops[0]      = mpiaij->cops;
34     mat->hash_active = PETSC_FALSE;
35   }
36 
37   PetscCall(PetscLayoutSetUp(mat->rmap));
38   PetscCall(PetscLayoutSetUp(mat->cmap));
39 #if defined(PETSC_USE_DEBUG)
40   if (d_nnz) {
41     PetscInt i;
42     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
43   }
44   if (o_nnz) {
45     PetscInt i;
46     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
47   }
48 #endif
49 #if defined(PETSC_USE_CTABLE)
50   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
51 #else
52   PetscCall(PetscFree(mpiaij->colmap));
53 #endif
54   PetscCall(PetscFree(mpiaij->garray));
55   PetscCall(VecDestroy(&mpiaij->lvec));
56   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
57   /* Because the B will have been resized we simply destroy it and create a new one each time */
58   PetscCall(MatDestroy(&mpiaij->B));
59 
60   if (!mpiaij->A) {
61     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
62     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
63   }
64   if (!mpiaij->B) {
65     PetscMPIInt size;
66     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
67     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
68     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
69   }
70   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
71   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
72   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
73   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
74   mat->preallocated = PETSC_TRUE;
75   PetscFunctionReturn(PETSC_SUCCESS);
76 }
77 
78 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
79 {
80   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
81   PetscInt    nt;
82 
83   PetscFunctionBegin;
84   PetscCall(VecGetLocalSize(xx, &nt));
85   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
86   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
87   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
88   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
89   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
90   PetscFunctionReturn(PETSC_SUCCESS);
91 }
92 
93 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
94 {
95   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
96   PetscInt    nt;
97 
98   PetscFunctionBegin;
99   PetscCall(VecGetLocalSize(xx, &nt));
100   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
101   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
102   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
103   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
104   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
105   PetscFunctionReturn(PETSC_SUCCESS);
106 }
107 
108 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
109 {
110   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
111   PetscInt    nt;
112 
113   PetscFunctionBegin;
114   PetscCall(VecGetLocalSize(xx, &nt));
115   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
116   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
117   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
118   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
120   PetscFunctionReturn(PETSC_SUCCESS);
121 }
122 
123 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
124    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
125    C still uses local column ids. Their corresponding global column ids are returned in glob.
126 */
127 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
128 {
129   Mat             Ad, Ao;
130   const PetscInt *cmap;
131 
132   PetscFunctionBegin;
133   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
134   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
135   if (glob) {
136     PetscInt cst, i, dn, on, *gidx;
137     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
138     PetscCall(MatGetLocalSize(Ao, NULL, &on));
139     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
140     PetscCall(PetscMalloc1(dn + on, &gidx));
141     for (i = 0; i < dn; i++) gidx[i] = cst + i;
142     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
143     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
144   }
145   PetscFunctionReturn(PETSC_SUCCESS);
146 }
147 
148 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
149 struct MatMatStruct {
150   PetscInt            n, *garray;     // C's garray and its size.
151   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
152   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
153   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
154   PetscIntKokkosView  E_NzLeft;
155   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
156   MatScalarKokkosView rootBuf, leafBuf;
157   KokkosCsrMatrix     Fd, Fo; // F in split form
158 
159   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
160   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
161   KernelHandle kh3; // compute C3
162   KernelHandle kh4; // compute C4
163 
164   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
165   PetscInt E_VectorLength;
166   PetscInt E_RowsPerTeam;
167   PetscInt F_TeamSize;
168   PetscInt F_VectorLength;
169   PetscInt F_RowsPerTeam;
170 
171   ~MatMatStruct()
172   {
173     PetscFunctionBegin;
174     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
175     PetscFunctionReturnVoid();
176   }
177 };
178 
179 struct MatMatStruct_AB : public MatMatStruct {
180   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
181   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
182   PetscIntKokkosView rowoffset;
183 };
184 
185 struct MatMatStruct_AtB : public MatMatStruct {
186   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
187   MatColIdxKokkosView Fdjperm;
188   MatColIdxKokkosView Fojmap;
189   MatColIdxKokkosView Fojperm;
190 };
191 
192 struct MatProductData_MPIAIJKokkos {
193   MatMatStruct_AB  *mmAB     = nullptr;
194   MatMatStruct_AtB *mmAtB    = nullptr;
195   PetscBool         reusesym = PETSC_FALSE;
196   Mat               Z        = nullptr; // store Z=AB in computing BtAB
197 
198   ~MatProductData_MPIAIJKokkos()
199   {
200     delete mmAB;
201     delete mmAtB;
202     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
203   }
204 };
205 
206 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
207 {
208   PetscFunctionBegin;
209   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
210   PetscFunctionReturn(PETSC_SUCCESS);
211 }
212 
213 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
214    It is similar to MatCreateMPIAIJWithSplitArrays.
215 
216   Input Parameters:
217 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
218 .  A     - the diag matrix using local col ids
219 -  B     - the offdiag matrix using global col ids
220 
221   Output Parameter:
222 .  mat   - the updated MATMPIAIJKOKKOS matrix
223 */
224 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
225 {
226   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
227   PetscInt    m, n, M, N, Am, An, Bm, Bn;
228 
229   PetscFunctionBegin;
230   PetscCall(MatGetSize(mat, &M, &N));
231   PetscCall(MatGetLocalSize(mat, &m, &n));
232   PetscCall(MatGetLocalSize(A, &Am, &An));
233   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
234 
235   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
236   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
237   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
238   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
239   mpiaij->A      = A;
240   mpiaij->B      = B;
241   mpiaij->garray = garray;
242 
243   mat->preallocated     = PETSC_TRUE;
244   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
245 
246   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
247   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
248   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
249     also gets mpiaij->B compacted, with its col ids and size reduced
250   */
251   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
252   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
253   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
254   PetscFunctionReturn(PETSC_SUCCESS);
255 }
256 
257 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
258 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
259 template <class ExecutionSpace>
260 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
261 {
262   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
263 
264   PetscFunctionBegin;
265   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
266 
267   if (nnz_per_row < 1) nnz_per_row = 1;
268 
269   int max_vector_length = teamPolicy.vector_length_max();
270 
271   if (vector_length < 1) {
272     vector_length = 1;
273     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
274   }
275 
276   // Determine rows per thread
277   if (rows_per_thread < 1) {
278     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
279     else {
280       if (nnz_per_row < 20 && nnz > 5000000) {
281         rows_per_thread = 256;
282       } else rows_per_thread = 64;
283     }
284   }
285 
286   if (team_size < 1) {
287     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
288       team_size = 256 / vector_length;
289     } else {
290       team_size = 1;
291     }
292   }
293 
294   rows_per_team = rows_per_thread * team_size;
295 
296   if (rows_per_team < 0) {
297     PetscInt nnz_per_team = 4096;
298     PetscInt conc         = ExecutionSpace().concurrency();
299     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
300     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
301   }
302   PetscFunctionReturn(PETSC_SUCCESS);
303 }
304 
305 /*
306   Reduce two sets of global indices into local ones
307 
308   Input Parameters:
309 +  n1          - size of garray1[], the first set
310 .  garray1[n1] - a sorted global index array (without duplicates)
311 .  m           - size of indices[], the second set
312 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
313 
314   Output Parameters:
315 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
316 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
317 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
318 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
319 
320    Example, say
321     n1         = 5
322     garray1[5] = {1, 4, 7, 8, 10}
323     m          = 4
324     indices[4] = {2, 4, 8, 9}
325 
326    Combining them together, we have 7 global indices in garray2[]
327     n2         = 7
328     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
329 
330    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
331     map[5] = {0, 2, 3, 4, 6}
332 
333    On output, indices[] is updated with local indices
334     indices[4] = {1, 2, 4, 5}
335 */
336 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
337 {
338   PetscHMapI    g2l = nullptr;
339   PetscHashIter iter;
340   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
341   PetscInt      n2, *garray2;
342 
343   PetscFunctionBegin;
344   tot = 0;
345   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
346   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
347     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
348     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
349   }
350 
351   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
352     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
353     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
354   }
355 
356   // Pull out (unique) globals in the hash table and put them in garray2[]
357   n2 = tot;
358   PetscCall(PetscMalloc1(n2, &garray2));
359   tot = 0;
360   PetscHashIterBegin(g2l, iter);
361   while (!PetscHashIterAtEnd(g2l, iter)) {
362     PetscHashIterGetKey(g2l, iter, key);
363     PetscHashIterNext(g2l, iter);
364     garray2[tot++] = key;
365   }
366 
367   // Sort garray2[] and then map them to local indices starting from 0
368   PetscCall(PetscSortInt(n2, garray2));
369   PetscCall(PetscHMapIClear(g2l));
370   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
371 
372   // Rewrite indices[] with local indices
373   for (PetscInt i = 0; i < m; i++) {
374     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
375     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
376     indices[i] = val;
377   }
378   // Record the map that maps garray1[i] to garray2[map[i]]
379   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
380   PetscCall(PetscHMapIDestroy(&g2l));
381   *n2_      = n2;
382   *garray2_ = garray2;
383   PetscFunctionReturn(PETSC_SUCCESS);
384 }
385 
386 /*
387   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
388 
389   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
390 
391   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
392   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
393 
394   Input Parameters:
395 +  comm       - MPI communicator of E
396 .  A          - diag block of E, using local column indices
397 .  B          - off-diag block of E, using local column indices
398 .  cstart      - (global) start column of Ed
399 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
400 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
401 .  ownerSF     - the SF specifies ownership (root) of rows in E
402 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
403 -  mm          - to stash intermediate data structures for reuse
404 
405   Output Parameters:
406 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
407 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
408 
409   Notes:
410   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
411 
412  */
413 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
414 {
415   PetscFunctionBegin;
416   if (reuse == MAT_INITIAL_MATRIX) {
417     PetscInt Em = A.numRows(), Fm;
418     PetscInt n1 = B.numCols();
419 
420     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
421 
422     // Do the analysis on host
423     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
424     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
425     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
426     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
427     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
428     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
429 
430     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
431     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
432     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
433     for (PetscInt i = 0; i < Em; i++) {
434       const PetscInt *first, *last, *it;
435       PetscInt        count, step;
436       // std::lower_bound(first,last,cstart), but need to use global column indices
437       first = Bj + Bi[i];
438       last  = Bj + Bi[i + 1];
439       count = last - first;
440       while (count > 0) {
441         it   = first;
442         step = count / 2;
443         it += step;
444         if (garray1[*it] < cstart) { // map local to global
445           first = ++it;
446           count -= step + 1;
447         } else count = step;
448       }
449       E_NzLeft[i] = first - (Bj + Bi[i]);
450       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
451     }
452 
453     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
454     const PetscMPIInt *iranks, *ranks;
455     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
456     PetscInt           niranks, nranks;
457     MPI_Request       *reqs;
458     PetscMPIInt        tag;
459     PetscSF            reduceSF;
460     PetscInt          *sdisp, *rdisp;
461 
462     PetscCall(PetscCommGetNewTag(comm, &tag));
463     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
464     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
465 
466     // Find out length of each row I will receive. Even for the same row index, when they are from
467     // different senders, they might have different lengths (and sparsity patterns)
468     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
469     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
470 
471     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
472 
473     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
474     recvRowLen[0] = 0; // since we will make it in CSR format later
475     recvRowLen++;      // advance the pointer now
476     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
477     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
478     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
479 
480     // Build the real PetscSF for reducing E rows (buffer to buffer)
481     rdisp[0] = 0;
482     for (PetscInt i = 0; i < niranks; i++) {
483       rdisp[i + 1] = rdisp[i];
484       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
485     }
486     recvRowLen--; // put it back into csr format
487     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
488 
489     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
490     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
491     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
492 
493     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
494     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
495     PetscSFNode *iremote;
496 
497     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
498     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
499     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
500 
501     for (PetscInt i = 0; i < nranks; i++) {
502       PetscInt count = 0;
503       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
504       for (PetscInt j = 0; j < count; j++) {
505         iremote[nleaves + j].rank  = ranks[i];
506         iremote[nleaves + j].index = sdisp[i] + j;
507       }
508       nleaves += count;
509     }
510     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
511 
512     PetscCall(PetscSFCreate(comm, &reduceSF));
513     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
514 
515     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
516     PetscInt *sendCol, *recvCol;
517     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
518     for (PetscInt k = 0; k < roffset[nranks]; k++) {
519       PetscInt  i      = rmine[k]; // row to be copied
520       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
521       PetscInt  nzLeft = E_NzLeft[i];
522       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
523       for (PetscInt j = 0; j < alen + blen; j++) {
524         if (j < nzLeft) {
525           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
526         } else if (j < nzLeft + alen) {
527           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
528         } else {
529           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
530         }
531       }
532     }
533     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
534     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
535 
536     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
537     PetscInt *recvRowPerm, *recvColSorted;
538     PetscInt *recvNzPerm, *recvNzPermSorted;
539     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
540 
541     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
542     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
543     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
544 
545     // i[] array, nz are always easiest to compute
546     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
547     MatRowMapType          *Fdi, *Foi;
548     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
549     PetscInt                iter;
550 
551     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
552     Kokkos::deep_copy(Foi_h, 0);
553     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
554     Foi  = Foi_h.data() + 1;
555     iter = 0;
556     while (iter < recvRowCnt) { // iter over received rows
557       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
558       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
559 
560       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
561 
562       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
563       PetscInt  nz    = 0; // nz (with dups) in the current row
564       PetscInt *jbuf  = recvColSorted + FnzDups;
565       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
566       PetscInt *jbuf2 = jbuf; // temp pointers
567       PetscInt *pbuf2 = pbuf;
568       for (PetscInt d = 0; d < dupRows; d++) {
569         PetscInt i   = recvRowPerm[iter + d];
570         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
571         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
572         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
573         jbuf2 += len;
574         pbuf2 += len;
575         nz += len;
576       }
577       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
578 
579       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
580       PetscInt cur = 0;
581       while (cur < nz) {
582         PetscInt curColIdx = jbuf[cur];
583         PetscInt dups      = 1;
584 
585         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
586         if (curColIdx >= cstart && curColIdx < cend) {
587           Fdi[curRowIdx]++;
588           FdnzDups += dups;
589         } else {
590           Foi[curRowIdx]++;
591           FonzDups += dups;
592         }
593         cur += dups;
594       }
595 
596       FnzDups += nz;
597       iter += dupRows; // Move to next unique row
598     }
599 
600     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
601     Foi = Foi_h.data();
602     for (PetscInt i = 0; i < Fm; i++) {
603       Fdi[i + 1] += Fdi[i];
604       Foi[i + 1] += Foi[i];
605     }
606     Fdnz = Fdi[Fm];
607     Fonz = Foi[Fm];
608     PetscCall(PetscFree2(sendCol, recvCol));
609 
610     // Allocate j, jmap, jperm for Fd and Fo
611     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
612     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
613     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
614     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
615     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
616     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
617 
618     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
619     Fdjmap[0] = 0;
620     Fojmap[0] = 0;
621     FnzDups   = 0;
622     Fdnz      = 0;
623     Fonz      = 0;
624     iter      = 0; // iter over received rows
625     while (iter < recvRowCnt) {
626       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
627       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
628       PetscInt nz        = 0;                           // nz (with dups) in the current row
629 
630       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
631       for (PetscInt d = 0; d < dupRows; d++) {
632         PetscInt i = recvRowPerm[iter + d];
633         nz += recvRowLen[i + 1] - recvRowLen[i];
634       }
635 
636       PetscInt *jbuf = recvColSorted + FnzDups;
637       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
638       PetscInt cur = 0;
639       while (cur < nz) {
640         PetscInt curColIdx = jbuf[cur];
641         PetscInt dups      = 1;
642 
643         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
644         if (curColIdx >= cstart && curColIdx < cend) {
645           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
646           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
647           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
648           FdnzDups += dups;
649           Fdnz++;
650         } else {
651           Foj[Fonz]        = curColIdx; // in global
652           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
653           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
654           FonzDups += dups;
655           Fonz++;
656         }
657         cur += dups;
658         FnzDups += dups;
659       }
660       iter += dupRows; // Move to next unique row
661     }
662     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
663     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
664 
665     // Combine global column indices in garray1[] and Foj[]
666     PetscInt n2, *garray2;
667 
668     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
669     mm->sf       = reduceSF;
670     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
671     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
672     mm->garray   = garray2; // give ownership, so no free
673     mm->n        = n2;
674     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
675     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
676     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
677     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
678     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
679 
680     // Output Fd and Fo in KokkosCsrMatrix format
681     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
682     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
683     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
684     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
685     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
686     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
687 
688     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
689     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
690 
691     // Compute kernel launch parameters in merging E
692     PetscInt teamSize, vectorLength, rowsPerTeam;
693 
694     teamSize = vectorLength = rowsPerTeam = -1;
695     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
696     mm->E_TeamSize     = teamSize;
697     mm->E_VectorLength = vectorLength;
698     mm->E_RowsPerTeam  = rowsPerTeam;
699   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
700 
701   // Handy aliases
702   auto       &Aa           = A.values;
703   auto       &Ba           = B.values;
704   const auto &Ai           = A.graph.row_map;
705   const auto &Bi           = B.graph.row_map;
706   const auto &E_NzLeft     = mm->E_NzLeft;
707   auto       &leafBuf      = mm->leafBuf;
708   auto       &rootBuf      = mm->rootBuf;
709   PetscSF     reduceSF     = mm->sf;
710   PetscInt    Em           = A.numRows();
711   PetscInt    teamSize     = mm->E_TeamSize;
712   PetscInt    vectorLength = mm->E_VectorLength;
713   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
714   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
715 
716   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
717   PetscCallCXX(Kokkos::parallel_for(
718     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
719       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
720         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
721         if (i < Em) {
722           PetscInt disp   = Ai(i) + Bi(i);
723           PetscInt alen   = Ai(i + 1) - Ai(i);
724           PetscInt blen   = Bi(i + 1) - Bi(i);
725           PetscInt nzleft = E_NzLeft(i);
726 
727           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
728             MatScalar &val = leafBuf(disp + j);
729             if (j < nzleft) { // B left
730               val = Ba(Bi(i) + j);
731             } else if (j < nzleft + alen) { // diag A
732               val = Aa(Ai(i) + j - nzleft);
733             } else { // B right
734               val = Ba(Bi(i) + j - alen);
735             }
736           });
737         }
738       });
739     }));
740   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
741   PetscFunctionReturn(PETSC_SUCCESS);
742 }
743 
744 // To finish MatMPIAIJKokkosReduce.
745 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
746 {
747   PetscFunctionBegin;
748   auto       &leafBuf  = mm->leafBuf;
749   auto       &rootBuf  = mm->rootBuf;
750   auto       &Fda      = mm->Fd.values;
751   const auto &Fdjmap   = mm->Fdjmap;
752   const auto &Fdjperm  = mm->Fdjperm;
753   auto        Fdnz     = mm->Fd.nnz();
754   auto       &Foa      = mm->Fo.values;
755   const auto &Fojmap   = mm->Fojmap;
756   const auto &Fojperm  = mm->Fojperm;
757   auto        Fonz     = mm->Fo.nnz();
758   PetscSF     reduceSF = mm->sf;
759 
760   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
761 
762   // Reduce data in rootBuf to Fd and Fo
763   PetscCallCXX(Kokkos::parallel_for(
764     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
765       PetscScalar sum = 0.0;
766       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
767       Fda(i) = sum;
768     }));
769 
770   PetscCallCXX(Kokkos::parallel_for(
771     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
772       PetscScalar sum = 0.0;
773       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
774       Foa(i) = sum;
775     }));
776   PetscFunctionReturn(PETSC_SUCCESS);
777 }
778 
779 /*
780   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
781 
782   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
783   device and involves various index mapping.
784 
785   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
786   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
787   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
788   F has the same column layout as E.
789 
790   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
791   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
792   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
793   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
794   column indices in Fo and update Fo with local indices.
795 
796    Input Parameters:
797 +   E       - the MPIAIJKOKKOS matrix
798 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
799 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
800 -   mm      - to stash matproduct intermediate data structures
801 
802     Output Parameters:
803 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
804 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
805 
806     Notes:
807     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
808     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
809 */
810 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
811 {
812   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
813   Mat               A = empi->A, B = empi->B; // diag and off-diag
814   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
815   PetscInt          Em = E->rmap->n; // #local rows
816   MPI_Comm          comm;
817 
818   PetscFunctionBegin;
819   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
820   if (reuse == MAT_INITIAL_MATRIX) {
821     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
822     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
823     const PetscInt *garray1 = empi->garray; // its size is n1
824     PetscInt        cstart, cend;
825     PetscSF         bcastSF;
826 
827     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
828 
829     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
830     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
831     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
832     for (PetscInt i = 0; i < Em; i++) {
833       const PetscInt *first, *last, *it;
834       PetscInt        count, step;
835       // std::lower_bound(first,last,cstart), but need to use global column indices
836       first = Bj + Bi[i];
837       last  = Bj + Bi[i + 1];
838       count = last - first;
839       while (count > 0) {
840         it   = first;
841         step = count / 2;
842         it += step;
843         if (empi->garray[*it] < cstart) { // map local to global
844           first = ++it;
845           count -= step + 1;
846         } else count = step;
847       }
848       E_NzLeft[i] = first - (Bj + Bi[i]);
849       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
850     }
851 
852     // Compute row pointer Fi of F
853     PetscInt *Fi, Fm, Fnz;
854     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
855     PetscCall(PetscMalloc1(Fm + 1, &Fi));
856     Fi[0] = 0;
857     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
858     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
859     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
860     Fnz = Fi[Fm];
861 
862     // Build the real PetscSF for bcasting E rows (buffer to buffer)
863     const PetscMPIInt *iranks, *ranks;
864     const PetscInt    *ioffset, *irootloc, *roffset;
865     PetscInt           niranks, nranks, *sdisp, *rdisp;
866     MPI_Request       *reqs;
867     PetscMPIInt        tag;
868 
869     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
870     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
871     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
872 
873     sdisp[0] = 0; // send displacement
874     for (PetscInt i = 0; i < niranks; i++) {
875       sdisp[i + 1] = sdisp[i];
876       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
877         PetscInt r = irootloc[j]; // row to be sent
878         sdisp[i + 1] += E_RowLen[r];
879       }
880     }
881 
882     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
883     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
884     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
885     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
886 
887     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
888     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
889     PetscSFNode *iremote;                  // give ownership to bcastSF
890     PetscCall(PetscMalloc1(nleaves, &iremote));
891     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
892       PetscInt k = 0;
893       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
894         iremote[j].rank  = ranks[i];
895         iremote[j].index = rdisp[i] + k; // their root location
896         k++;
897       }
898     }
899     PetscCall(PetscSFCreate(comm, &bcastSF));
900     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
901     PetscCall(PetscFree3(sdisp, rdisp, reqs));
902 
903     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
904     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
905     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
906     rowoffset[0]                     = 0;
907     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
908 
909     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
910     PetscInt *jbuf, *Fj;
911     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
912     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
913       PetscInt  i      = irootloc[k]; // row to be copied
914       PetscInt *buf    = &jbuf[rowoffset[k]];
915       PetscInt  nzLeft = E_NzLeft[i];
916       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
917       for (PetscInt j = 0; j < alen + blen; j++) {
918         if (j < nzLeft) {
919           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
920         } else if (j < nzLeft + alen) {
921           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
922         } else {
923           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
924         }
925       }
926     }
927     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
928     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
929 
930     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
931     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
932     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
933     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
934     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
935 
936     Fdi[0] = Foi[0] = 0;
937     for (PetscInt i = 0; i < Fm; i++) {
938       PetscInt *first, *last, *lb1, *lb2;
939       // cut the row into: Left, [cstart, cend), Right
940       first       = Fj + Fi[i];
941       last        = Fj + Fi[i + 1];
942       lb1         = std::lower_bound(first, last, cstart);
943       F_NzLeft[i] = lb1 - first;
944       lb2         = std::lower_bound(first, last, cend);
945       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
946       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
947     }
948     for (PetscInt i = 0; i < Fm; i++) {
949       Fdi[i + 1] += Fdi[i];
950       Foi[i + 1] += Foi[i];
951     }
952 
953     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
954     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
955     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
956     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
957 
958     for (PetscInt i = 0; i < Fm; i++) {
959       PetscInt nzLeft = F_NzLeft[i];
960       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
961       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
962         gid = Fj[Fi[i] + j];
963         if (j < nzLeft) { // left, in global
964           Foj[Foi[i] + j] = gid;
965         } else if (j < nzLeft + len) { // diag, in local
966           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
967         } else { // right, in global
968           Foj[Foi[i] + j - len] = gid;
969         }
970       }
971     }
972     PetscCall(PetscFree2(jbuf, Fj));
973     PetscCall(PetscFree(Fi));
974 
975     // Reduce global indices in Foj[] and garray1[] into local ones
976     PetscInt n2, *garray2;
977     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
978 
979     // Record the plans built above, for reuse
980     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
981     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
982     Kokkos::deep_copy(irootloc_h, tmp);
983     mm->sf        = bcastSF;
984     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
985     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
986     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
987     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
988     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
989     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
990     mm->garray    = garray2;
991     mm->n         = n2;
992 
993     // Output Fd and Fo in KokkosCsrMatrix format
994     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
995     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
996     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
997     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
998     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
999 
1000     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1001     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1002 
1003     // Compute kernel launch parameters in merging E or splitting F
1004     PetscInt teamSize, vectorLength, rowsPerTeam;
1005 
1006     teamSize = vectorLength = rowsPerTeam = -1;
1007     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1008     mm->E_TeamSize     = teamSize;
1009     mm->E_VectorLength = vectorLength;
1010     mm->E_RowsPerTeam  = rowsPerTeam;
1011 
1012     teamSize = vectorLength = rowsPerTeam = -1;
1013     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1014     mm->F_TeamSize     = teamSize;
1015     mm->F_VectorLength = vectorLength;
1016     mm->F_RowsPerTeam  = rowsPerTeam;
1017   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1018 
1019   // Sync E's value to device
1020   akok->a_dual.sync_device();
1021   bkok->a_dual.sync_device();
1022 
1023   // Handy aliases
1024   const auto &Aa = akok->a_dual.view_device();
1025   const auto &Ba = bkok->a_dual.view_device();
1026   const auto &Ai = akok->i_dual.view_device();
1027   const auto &Bi = bkok->i_dual.view_device();
1028 
1029   // Fetch the plans
1030   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1031   PetscSF             &bcastSF   = mm->sf;
1032   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1033   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1034   PetscIntKokkosView  &irootloc  = mm->irootloc;
1035   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1036 
1037   PetscInt teamSize     = mm->E_TeamSize;
1038   PetscInt vectorLength = mm->E_VectorLength;
1039   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1040   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1041 
1042   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1043   PetscCallCXX(Kokkos::parallel_for(
1044     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1045       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1046         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1047         if (r < irootloc.extent(0)) {
1048           PetscInt i      = irootloc(r); // row i of E
1049           PetscInt disp   = rowoffset(r);
1050           PetscInt alen   = Ai(i + 1) - Ai(i);
1051           PetscInt blen   = Bi(i + 1) - Bi(i);
1052           PetscInt nzleft = E_NzLeft(i);
1053 
1054           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1055             if (j < nzleft) { // B left
1056               rootBuf(disp + j) = Ba(Bi(i) + j);
1057             } else if (j < nzleft + alen) { // diag A
1058               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1059             } else { // B right
1060               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1061             }
1062           });
1063         }
1064       });
1065     }));
1066   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1067   PetscFunctionReturn(PETSC_SUCCESS);
1068 }
1069 
1070 // To finish MatMPIAIJKokkosBcast.
1071 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1072 {
1073   PetscFunctionBegin;
1074   const auto &Fd  = mm->Fd;
1075   const auto &Fo  = mm->Fo;
1076   const auto &Fdi = Fd.graph.row_map;
1077   const auto &Foi = Fo.graph.row_map;
1078   auto       &Fda = Fd.values;
1079   auto       &Foa = Fo.values;
1080   auto        Fm  = Fd.numRows();
1081 
1082   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1083   PetscSF             &bcastSF      = mm->sf;
1084   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1085   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1086   PetscInt             teamSize     = mm->F_TeamSize;
1087   PetscInt             vectorLength = mm->F_VectorLength;
1088   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1089   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1090 
1091   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1092 
1093   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1094   PetscCallCXX(Kokkos::parallel_for(
1095     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1096       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1097         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1098         if (i < Fm) {
1099           PetscInt nzLeft = F_NzLeft(i);
1100           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1101           PetscInt blen   = Foi(i + 1) - Foi(i);
1102           PetscInt Fii    = Fdi(i) + Foi(i);
1103 
1104           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1105             PetscScalar val = leafBuf(Fii + j);
1106             if (j < nzLeft) { // left
1107               Foa(Foi(i) + j) = val;
1108             } else if (j < nzLeft + alen) { // diag
1109               Fda(Fdi(i) + j - nzLeft) = val;
1110             } else { // right
1111               Foa(Foi(i) + j - alen) = val;
1112             }
1113           });
1114         }
1115       });
1116     }));
1117   PetscFunctionReturn(PETSC_SUCCESS);
1118 }
1119 
1120 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1121 {
1122   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1123   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1124   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1125   PetscInt        cstart, cend;
1126   MPI_Comm        comm;
1127 
1128   PetscFunctionBegin;
1129   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1130   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1131   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1132   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1133   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1134   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1135   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1136 
1137   // TODO: add command line options to select spgemm algorithms
1138   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1139 
1140   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1141 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1142   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1143   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1144   #endif
1145 #endif
1146 
1147   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1148   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1149   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1150   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1151 
1152   // Aot * (B's diag + B's off-diag)
1153   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1154   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1155   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1156   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1157   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1158   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1159 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1160   PetscCallCXX(sort_crs_matrix(mm->C3));
1161   PetscCallCXX(sort_crs_matrix(mm->C4));
1162 #endif
1163 
1164   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1165   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1166   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1167   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1168 
1169   // Adt * (B's diag + B's off-diag)
1170   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1171   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1172   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1173   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1174 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1175   PetscCallCXX(sort_crs_matrix(mm->C1));
1176   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1177 #endif
1178 
1179   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1180 
1181   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1182   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1183   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1184   PetscCallCXX(Kokkos::parallel_for(oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1185   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1186 
1187   // C = (C1+Fd, C2+Fo)
1188   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1189   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1190   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1191   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1192   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1193   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1194   PetscFunctionReturn(PETSC_SUCCESS);
1195 }
1196 
1197 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1198 {
1199   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1200   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1201   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1202   MPI_Comm        comm;
1203 
1204   PetscFunctionBegin;
1205   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1206   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1207   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1208   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1209   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1210 
1211   // Aot * (B's diag + B's off-diag)
1212   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1213   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1214 
1215   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1216   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1217 
1218   // Adt * (B's diag + B's off-diag)
1219   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1220   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1221 
1222   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1223 
1224   // C = (C1+Fd, C2+Fo)
1225   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1226   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1227   PetscFunctionReturn(PETSC_SUCCESS);
1228 }
1229 
1230 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1231 
1232   Input Parameters:
1233 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1234 .  A        - an MPIAIJKOKKOS matrix
1235 .  B        - an MPIAIJKOKKOS matrix
1236 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1237 */
1238 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1239 {
1240   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1241   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1242   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1243 
1244   PetscFunctionBegin;
1245   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1246   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1247   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1248   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1249 
1250   // TODO: add command line options to select spgemm algorithms
1251   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1252 
1253   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1254 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1255   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1256   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1257   #endif
1258 #endif
1259 
1260   mm->kh1.create_spgemm_handle(spgemm_alg);
1261   mm->kh2.create_spgemm_handle(spgemm_alg);
1262   mm->kh3.create_spgemm_handle(spgemm_alg);
1263   mm->kh4.create_spgemm_handle(spgemm_alg);
1264 
1265   // Bcast B's rows to form F, and overlap the communication
1266   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1267   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1268 
1269   // A's diag * (B's diag + B's off-diag)
1270   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1271   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1272   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1273   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1274   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1275   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1276 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1277   PetscCallCXX(sort_crs_matrix(mm->C1));
1278   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1279 #endif
1280 
1281   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1282 
1283   // A's off-diag * (F's diag + F's off-diag)
1284   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1285   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1286   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1287   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1288 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1289   PetscCallCXX(sort_crs_matrix(mm->C3));
1290   PetscCallCXX(sort_crs_matrix(mm->C4));
1291 #endif
1292 
1293   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1294   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1295   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1296   PetscCallCXX(Kokkos::parallel_for(oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1297   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1298 
1299   // C = (Cd, Co) = (C1+C3, C2+C4)
1300   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1301   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1302   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1303   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1304   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1305   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1306   PetscFunctionReturn(PETSC_SUCCESS);
1307 }
1308 
1309 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1310 {
1311   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1312   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1313   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1314 
1315   PetscFunctionBegin;
1316   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1317   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1318   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1319   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1320 
1321   // Bcast B's rows to form F, and overlap the communication
1322   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1323 
1324   // A's diag * (B's diag + B's off-diag)
1325   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1326   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1327 
1328   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1329 
1330   // A's off-diag * (F's diag + F's off-diag)
1331   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1332   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1333 
1334   // C = (Cd, Co) = (C1+C3, C2+C4)
1335   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1336   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1337   PetscFunctionReturn(PETSC_SUCCESS);
1338 }
1339 
1340 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1341 {
1342   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1343   Mat_Product                 *product;
1344   MatProductData_MPIAIJKokkos *pdata;
1345   MatProductType               ptype;
1346   Mat                          A, B;
1347 
1348   PetscFunctionBegin;
1349   MatCheckProduct(C, 1); // make sure C is a product
1350   product = C->product;
1351   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1352   ptype   = product->type;
1353   A       = product->A;
1354   B       = product->B;
1355 
1356   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1357   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1358   // we still do numeric.
1359   if (pdata->reusesym) { // numeric reuses results from symbolic
1360     pdata->reusesym = PETSC_FALSE;
1361     PetscFunctionReturn(PETSC_SUCCESS);
1362   }
1363 
1364   if (ptype == MATPRODUCT_AB) {
1365     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1366   } else if (ptype == MATPRODUCT_AtB) {
1367     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1368   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1369     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1370     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1371   }
1372 
1373   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1374   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1375   PetscFunctionReturn(PETSC_SUCCESS);
1376 }
1377 
1378 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1379 {
1380   Mat                          A, B;
1381   Mat_Product                 *product;
1382   MatProductType               ptype;
1383   MatProductData_MPIAIJKokkos *pdata;
1384   MatMatStruct                *mm = NULL;
1385   PetscInt                     m, n, M, N;
1386   Mat                          Cd, Co;
1387   MPI_Comm                     comm;
1388 
1389   PetscFunctionBegin;
1390   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1391   MatCheckProduct(C, 1);
1392   product = C->product;
1393   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1394   ptype = product->type;
1395   A     = product->A;
1396   B     = product->B;
1397 
1398   switch (ptype) {
1399   case MATPRODUCT_AB:
1400     m = A->rmap->n;
1401     n = B->cmap->n;
1402     M = A->rmap->N;
1403     N = B->cmap->N;
1404     break;
1405   case MATPRODUCT_AtB:
1406     m = A->cmap->n;
1407     n = B->cmap->n;
1408     M = A->cmap->N;
1409     N = B->cmap->N;
1410     break;
1411   case MATPRODUCT_PtAP:
1412     m = B->cmap->n;
1413     n = B->cmap->n;
1414     M = B->cmap->N;
1415     N = B->cmap->N;
1416     break; /* BtAB */
1417   default:
1418     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1419   }
1420 
1421   PetscCall(MatSetSizes(C, m, n, M, N));
1422   PetscCall(PetscLayoutSetUp(C->rmap));
1423   PetscCall(PetscLayoutSetUp(C->cmap));
1424   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1425 
1426   pdata           = new MatProductData_MPIAIJKokkos();
1427   pdata->reusesym = product->api_user;
1428 
1429   if (ptype == MATPRODUCT_AB) {
1430     auto mmAB = new MatMatStruct_AB();
1431     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1432     mm = pdata->mmAB = mmAB;
1433   } else if (ptype == MATPRODUCT_AtB) {
1434     auto mmAtB = new MatMatStruct_AtB();
1435     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1436     mm = pdata->mmAtB = mmAtB;
1437   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1438     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1439 
1440     auto mmAB = new MatMatStruct_AB();
1441     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1442     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1443     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1444     pdata->mmAB = mmAB;
1445 
1446     m = A->rmap->n; // Z's layout
1447     n = B->cmap->n;
1448     M = A->rmap->N;
1449     N = B->cmap->N;
1450     PetscCall(MatCreate(comm, &Z));
1451     PetscCall(MatSetSizes(Z, m, n, M, N));
1452     PetscCall(PetscLayoutSetUp(Z->rmap));
1453     PetscCall(PetscLayoutSetUp(Z->cmap));
1454     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1455     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1456 
1457     auto mmAtB = new MatMatStruct_AtB();
1458     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1459 
1460     pdata->Z = Z; // give ownership to pdata
1461     mm = pdata->mmAtB = mmAtB;
1462   }
1463 
1464   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1465   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1466   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1467 
1468   C->product->data       = pdata;
1469   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1470   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1471   PetscFunctionReturn(PETSC_SUCCESS);
1472 }
1473 
1474 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1475 {
1476   Mat_Product *product = mat->product;
1477   PetscBool    match   = PETSC_FALSE;
1478   PetscBool    usecpu  = PETSC_FALSE;
1479 
1480   PetscFunctionBegin;
1481   MatCheckProduct(mat, 1);
1482   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1483   if (match) { /* we can always fallback to the CPU if requested */
1484     switch (product->type) {
1485     case MATPRODUCT_AB:
1486       if (product->api_user) {
1487         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1488         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1489         PetscOptionsEnd();
1490       } else {
1491         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1492         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1493         PetscOptionsEnd();
1494       }
1495       break;
1496     case MATPRODUCT_AtB:
1497       if (product->api_user) {
1498         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1499         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1500         PetscOptionsEnd();
1501       } else {
1502         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1503         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1504         PetscOptionsEnd();
1505       }
1506       break;
1507     case MATPRODUCT_PtAP:
1508       if (product->api_user) {
1509         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1510         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1511         PetscOptionsEnd();
1512       } else {
1513         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1514         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1515         PetscOptionsEnd();
1516       }
1517       break;
1518     default:
1519       break;
1520     }
1521     match = (PetscBool)!usecpu;
1522   }
1523   if (match) {
1524     switch (product->type) {
1525     case MATPRODUCT_AB:
1526     case MATPRODUCT_AtB:
1527     case MATPRODUCT_PtAP:
1528       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1529       break;
1530     default:
1531       break;
1532     }
1533   }
1534   /* fallback to MPIAIJ ops */
1535   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1536   PetscFunctionReturn(PETSC_SUCCESS);
1537 }
1538 
1539 // Mirror of MatCOOStruct_MPIAIJ on device
1540 struct MatCOOStruct_MPIAIJKokkos {
1541   PetscCount           n;
1542   PetscSF              sf;
1543   PetscCount           Annz, Bnnz;
1544   PetscCount           Annz2, Bnnz2;
1545   PetscCountKokkosView Ajmap1, Aperm1;
1546   PetscCountKokkosView Bjmap1, Bperm1;
1547   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1548   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1549   PetscCountKokkosView Cperm1;
1550   MatScalarKokkosView  sendbuf, recvbuf;
1551 
1552   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
1553     n(coo_h->n),
1554     sf(coo_h->sf),
1555     Annz(coo_h->Annz),
1556     Bnnz(coo_h->Bnnz),
1557     Annz2(coo_h->Annz2),
1558     Bnnz2(coo_h->Bnnz2),
1559     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
1560     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
1561     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
1562     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
1563     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
1564     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
1565     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
1566     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
1567     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
1568     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
1569     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
1570     sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
1571     recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
1572   {
1573     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1574   }
1575 
1576   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1577 };
1578 
1579 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1580 {
1581   PetscFunctionBegin;
1582   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1583   PetscFunctionReturn(PETSC_SUCCESS);
1584 }
1585 
1586 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1587 {
1588   PetscContainer             container_h, container_d;
1589   MatCOOStruct_MPIAIJ       *coo_h;
1590   MatCOOStruct_MPIAIJKokkos *coo_d;
1591 
1592   PetscFunctionBegin;
1593   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1594   mat->preallocated = PETSC_TRUE;
1595   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1596   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1597   PetscCall(MatZeroEntries(mat));
1598 
1599   // Copy the COO struct to device
1600   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1601   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1602   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1603 
1604   // Put the COO struct in a container and then attach that to the matrix
1605   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1606   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1607   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1608   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1609   PetscCall(PetscContainerDestroy(&container_d));
1610   PetscFunctionReturn(PETSC_SUCCESS);
1611 }
1612 
1613 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1614 {
1615   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1616   Mat                        A = mpiaij->A, B = mpiaij->B;
1617   MatScalarKokkosView        Aa, Ba;
1618   MatScalarKokkosView        v1;
1619   PetscMemType               memtype;
1620   PetscContainer             container;
1621   MatCOOStruct_MPIAIJKokkos *coo;
1622 
1623   PetscFunctionBegin;
1624   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1625   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1626 
1627   const auto &n      = coo->n;
1628   const auto &Annz   = coo->Annz;
1629   const auto &Annz2  = coo->Annz2;
1630   const auto &Bnnz   = coo->Bnnz;
1631   const auto &Bnnz2  = coo->Bnnz2;
1632   const auto &vsend  = coo->sendbuf;
1633   const auto &v2     = coo->recvbuf;
1634   const auto &Ajmap1 = coo->Ajmap1;
1635   const auto &Ajmap2 = coo->Ajmap2;
1636   const auto &Aimap2 = coo->Aimap2;
1637   const auto &Bjmap1 = coo->Bjmap1;
1638   const auto &Bjmap2 = coo->Bjmap2;
1639   const auto &Bimap2 = coo->Bimap2;
1640   const auto &Aperm1 = coo->Aperm1;
1641   const auto &Aperm2 = coo->Aperm2;
1642   const auto &Bperm1 = coo->Bperm1;
1643   const auto &Bperm2 = coo->Bperm2;
1644   const auto &Cperm1 = coo->Cperm1;
1645 
1646   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1647   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1648     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
1649   } else {
1650     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1651   }
1652 
1653   if (imode == INSERT_VALUES) {
1654     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1655     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1656   } else {
1657     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1658     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1659   }
1660 
1661   PetscCall(PetscLogGpuTimeBegin());
1662   /* Pack entries to be sent to remote */
1663   Kokkos::parallel_for(vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1664 
1665   /* Send remote entries to their owner and overlap the communication with local computation */
1666   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1667   /* Add local entries to A and B in one kernel */
1668   Kokkos::parallel_for(
1669     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1670       PetscScalar sum = 0.0;
1671       if (i < Annz) {
1672         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1673         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1674       } else {
1675         i -= Annz;
1676         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1677         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1678       }
1679     });
1680   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1681 
1682   /* Add received remote entries to A and B in one kernel */
1683   Kokkos::parallel_for(
1684     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1685       if (i < Annz2) {
1686         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1687       } else {
1688         i -= Annz2;
1689         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1690       }
1691     });
1692   PetscCall(PetscLogGpuTimeEnd());
1693 
1694   if (imode == INSERT_VALUES) {
1695     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1696     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1697   } else {
1698     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1699     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1700   }
1701   PetscFunctionReturn(PETSC_SUCCESS);
1702 }
1703 
1704 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1705 {
1706   PetscFunctionBegin;
1707   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1708   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1709   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1710   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1711   PetscCall(MatDestroy_MPIAIJ(A));
1712   PetscFunctionReturn(PETSC_SUCCESS);
1713 }
1714 
1715 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1716 {
1717   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1718   PetscBool   congruent;
1719 
1720   PetscFunctionBegin;
1721   PetscCall(MatHasCongruentLayouts(A, &congruent));
1722   if (congruent) { // square matrix and the diagonals are solely in the diag block
1723     PetscCall(MatShift(mpiaij->A, a));
1724   } else { // too hard, use the general version
1725     PetscCall(MatShift_Basic(A, a));
1726   }
1727   PetscFunctionReturn(PETSC_SUCCESS);
1728 }
1729 
1730 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1731 {
1732   PetscFunctionBegin;
1733   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1734   B->ops->mult                  = MatMult_MPIAIJKokkos;
1735   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1736   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1737   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1738   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1739   B->ops->shift                 = MatShift_MPIAIJKokkos;
1740 
1741   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1742   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1743   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1744   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1745   PetscFunctionReturn(PETSC_SUCCESS);
1746 }
1747 
1748 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1749 {
1750   Mat         B;
1751   Mat_MPIAIJ *a;
1752 
1753   PetscFunctionBegin;
1754   if (reuse == MAT_INITIAL_MATRIX) {
1755     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1756   } else if (reuse == MAT_REUSE_MATRIX) {
1757     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1758   }
1759   B = *newmat;
1760 
1761   B->boundtocpu = PETSC_FALSE;
1762   PetscCall(PetscFree(B->defaultvectype));
1763   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1764   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1765 
1766   a = static_cast<Mat_MPIAIJ *>(A->data);
1767   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1768   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1769   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1770   PetscCall(MatSetOps_MPIAIJKokkos(B));
1771   PetscFunctionReturn(PETSC_SUCCESS);
1772 }
1773 
1774 /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1776 
1777    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1778 
1779    Options Database Key:
1780 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1781 
1782   Level: beginner
1783 
1784 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1785 M*/
1786 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1787 {
1788   PetscFunctionBegin;
1789   PetscCall(PetscKokkosInitializeCheck());
1790   PetscCall(MatCreate_MPIAIJ(A));
1791   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1792   PetscFunctionReturn(PETSC_SUCCESS);
1793 }
1794 
1795 /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1799 
1800   Collective
1801 
1802   Input Parameters:
+ comm  - MPI communicator
1804 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1805            This value should be the same as the local size used in creating the
1806            y vector for the matrix-vector product y = Ax.
1807 . n     - This value should be the same as the local size used in creating the
1808        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1809        calculated if N is given) For square matrices n is almost always `m`.
1810 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1811 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1812 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1813            (same value is used for all local rows)
1814 . d_nnz - array containing the number of nonzeros in the various rows of the
1815            DIAGONAL portion of the local submatrix (possibly different for each row)
1816            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1817            The size of this array is equal to the number of local rows, i.e `m`.
1818            For matrices you plan to factor you must leave room for the diagonal entry and
1819            put in the entry even if it is zero.
1820 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1821            submatrix (same value is used for all local rows).
1822 - o_nnz - array containing the number of nonzeros in the various rows of the
1823            OFF-DIAGONAL portion of the local submatrix (possibly different for
1824            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1825            structure. The size of this array is equal to the number
1826            of local rows, i.e `m`.
1827 
1828   Output Parameter:
1829 . A - the matrix
1830 
1831   Level: intermediate
1832 
1833   Notes:
1834   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1835   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1836   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1837 
  The AIJ format (also called compressed row storage), is fully compatible with standard Fortran
1839   storage.  That is, the stored row and column indices can begin at
1840   either one (as in Fortran) or zero.
1841 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1844 @*/
1845 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1846 {
1847   PetscMPIInt size;
1848 
1849   PetscFunctionBegin;
1850   PetscCall(MatCreate(comm, A));
1851   PetscCall(MatSetSizes(*A, m, n, M, N));
1852   PetscCallMPI(MPI_Comm_size(comm, &size));
1853   if (size > 1) {
1854     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1855     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1856   } else {
1857     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1858     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1859   }
1860   PetscFunctionReturn(PETSC_SUCCESS);
1861 }
1862