xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 5d6b2bfc43f03a5db9a11c81ae96edf7e0c5e260)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
6 #include <../src/mat/impls/aij/mpi/mpiaij.h>
7 #include <KokkosSparse_spadd.hpp>
8 #include <KokkosSparse_spgemm.hpp>
9 
10 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11 {
12   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
13 
14   PetscFunctionBegin;
15   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
16   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
17      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
18    */
19   if (mode == MAT_FINAL_ASSEMBLY) {
20     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
21     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
22     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
23   }
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
27 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
28 {
29   Mat_MPIAIJ *mpiaij;
30 
31   PetscFunctionBegin;
32   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
33   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
34   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
35   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
36   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
37   PetscFunctionReturn(PETSC_SUCCESS);
38 }
39 
40 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
41 {
42   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
43   PetscInt    nt;
44 
45   PetscFunctionBegin;
46   PetscCall(VecGetLocalSize(xx, &nt));
47   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
48   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
49   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
50   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
51   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
52   PetscFunctionReturn(PETSC_SUCCESS);
53 }
54 
55 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
56 {
57   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
58   PetscInt    nt;
59 
60   PetscFunctionBegin;
61   PetscCall(VecGetLocalSize(xx, &nt));
62   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
63   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
64   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
65   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
66   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
67   PetscFunctionReturn(PETSC_SUCCESS);
68 }
69 
70 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
71 {
72   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
73   PetscInt    nt;
74 
75   PetscFunctionBegin;
76   PetscCall(VecGetLocalSize(xx, &nt));
77   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
78   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
79   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
80   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
81   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
82   PetscFunctionReturn(PETSC_SUCCESS);
83 }
84 
85 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
86    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
87    C still uses local column ids. Their corresponding global column ids are returned in glob.
88 */
89 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
90 {
91   Mat             Ad, Ao;
92   const PetscInt *cmap;
93 
94   PetscFunctionBegin;
95   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
96   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
97   if (glob) {
98     PetscInt cst, i, dn, on, *gidx;
99     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
100     PetscCall(MatGetLocalSize(Ao, NULL, &on));
101     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
102     PetscCall(PetscMalloc1(dn + on, &gidx));
103     for (i = 0; i < dn; i++) gidx[i] = cst + i;
104     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
105     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
106   }
107   PetscFunctionReturn(PETSC_SUCCESS);
108 }
109 
110 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
111 struct MatMatStruct {
112   PetscInt            n, *garray;     // C's garray and its size.
113   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
114   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
115   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
116   PetscIntKokkosView  E_NzLeft;
117   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
118   MatScalarKokkosView rootBuf, leafBuf;
119   KokkosCsrMatrix     Fd, Fo; // F in split form
120 
121   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
122   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
123   KernelHandle kh3; // compute C3
124   KernelHandle kh4; // compute C4
125 
126   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
127   PetscInt E_VectorLength;
128   PetscInt E_RowsPerTeam;
129   PetscInt F_TeamSize;
130   PetscInt F_VectorLength;
131   PetscInt F_RowsPerTeam;
132 
133   ~MatMatStruct()
134   {
135     PetscFunctionBegin;
136     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
137     PetscFunctionReturnVoid();
138   }
139 };
140 
141 struct MatMatStruct_AB : public MatMatStruct {
142   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
143   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
144   PetscIntKokkosView rowoffset;
145 };
146 
147 struct MatMatStruct_AtB : public MatMatStruct {
148   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
149   MatColIdxKokkosView Fdjperm;
150   MatColIdxKokkosView Fojmap;
151   MatColIdxKokkosView Fojperm;
152 };
153 
154 struct MatProductData_MPIAIJKokkos {
155   MatMatStruct_AB  *mmAB     = nullptr;
156   MatMatStruct_AtB *mmAtB    = nullptr;
157   PetscBool         reusesym = PETSC_FALSE;
158   Mat               Z        = nullptr; // store Z=AB in computing BtAB
159 
160   ~MatProductData_MPIAIJKokkos()
161   {
162     delete mmAB;
163     delete mmAtB;
164     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
165   }
166 };
167 
168 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
169 {
170   PetscFunctionBegin;
171   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
172   PetscFunctionReturn(PETSC_SUCCESS);
173 }
174 
175 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
176    It is similar to MatCreateMPIAIJWithSplitArrays.
177 
178   Input Parameters:
179 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
180 .  A     - the diag matrix using local col ids
181 -  B     - the offdiag matrix using global col ids
182 
183   Output Parameter:
184 .  mat   - the updated MATMPIAIJKOKKOS matrix
185 */
186 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
187 {
188   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
189   PetscInt    m, n, M, N, Am, An, Bm, Bn;
190 
191   PetscFunctionBegin;
192   PetscCall(MatGetSize(mat, &M, &N));
193   PetscCall(MatGetLocalSize(mat, &m, &n));
194   PetscCall(MatGetLocalSize(A, &Am, &An));
195   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
196 
197   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
198   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
199   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
200   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
201   mpiaij->A      = A;
202   mpiaij->B      = B;
203   mpiaij->garray = garray;
204 
205   mat->preallocated     = PETSC_TRUE;
206   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
207 
208   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
209   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
210   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
211     also gets mpiaij->B compacted, with its col ids and size reduced
212   */
213   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
214   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
215   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
220 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
221 template <class ExecutionSpace>
222 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
223 {
224   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
225 
226   PetscFunctionBegin;
227   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
228 
229   if (nnz_per_row < 1) nnz_per_row = 1;
230 
231   int max_vector_length = teamPolicy.vector_length_max();
232 
233   if (vector_length < 1) {
234     vector_length = 1;
235     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
236   }
237 
238   // Determine rows per thread
239   if (rows_per_thread < 1) {
240     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
241     else {
242       if (nnz_per_row < 20 && nnz > 5000000) {
243         rows_per_thread = 256;
244       } else rows_per_thread = 64;
245     }
246   }
247 
248   if (team_size < 1) {
249     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
250       team_size = 256 / vector_length;
251     } else {
252       team_size = 1;
253     }
254   }
255 
256   rows_per_team = rows_per_thread * team_size;
257 
258   if (rows_per_team < 0) {
259     PetscInt nnz_per_team = 4096;
260     PetscInt conc         = ExecutionSpace().concurrency();
261     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
262     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
263   }
264   PetscFunctionReturn(PETSC_SUCCESS);
265 }
266 
267 /*
268   Reduce two sets of global indices into local ones
269 
270   Input Parameters:
271 +  n1          - size of garray1[], the first set
272 .  garray1[n1] - a sorted global index array (without duplicates)
273 .  m           - size of indices[], the second set
274 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
275 
276   Output Parameters:
277 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
278 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
279 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
280 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
281 
282    Example, say
283     n1         = 5
284     garray1[5] = {1, 4, 7, 8, 10}
285     m          = 4
286     indices[4] = {2, 4, 8, 9}
287 
288    Combining them together, we have 7 global indices in garray2[]
289     n2         = 7
290     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
291 
292    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
293     map[5] = {0, 2, 3, 4, 6}
294 
295    On output, indices[] is updated with local indices
296     indices[4] = {1, 2, 4, 5}
297 */
298 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
299 {
300   PetscHMapI    g2l = nullptr;
301   PetscHashIter iter;
302   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
303   PetscInt      n2, *garray2;
304 
305   PetscFunctionBegin;
306   tot = 0;
307   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
308   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
309     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
310     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
311   }
312 
313   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
314     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
315     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
316   }
317 
318   // Pull out (unique) globals in the hash table and put them in garray2[]
319   n2 = tot;
320   PetscCall(PetscMalloc1(n2, &garray2));
321   tot = 0;
322   PetscHashIterBegin(g2l, iter);
323   while (!PetscHashIterAtEnd(g2l, iter)) {
324     PetscHashIterGetKey(g2l, iter, key);
325     PetscHashIterNext(g2l, iter);
326     garray2[tot++] = key;
327   }
328 
329   // Sort garray2[] and then map them to local indices starting from 0
330   PetscCall(PetscSortInt(n2, garray2));
331   PetscCall(PetscHMapIClear(g2l));
332   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
333 
334   // Rewrite indices[] with local indices
335   for (PetscInt i = 0; i < m; i++) {
336     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
337     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
338     indices[i] = val;
339   }
340   // Record the map that maps garray1[i] to garray2[map[i]]
341   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
342   PetscCall(PetscHMapIDestroy(&g2l));
343   *n2_      = n2;
344   *garray2_ = garray2;
345   PetscFunctionReturn(PETSC_SUCCESS);
346 }
347 
/*
  MatMPIAIJKokkosReduceBegin - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)

  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature, since we do not need a fully populated MPIAIJKOKKOS E.

  Think of each row of E as a leaf; the given ownerSF then specifies a root for each leaf. A root may be connected to multiple leaves.
  In this routine, we sparse-merge the leaves (rows) at their roots to form potentially longer rows in F. F's number of rows is the number of roots of ownerSF.

  Input Parameters:
+  comm        - MPI communicator of E
.  A           - diag block of E (Ed), using local column indices
.  B           - off-diag block of E (Eo), using local column indices
.  cstart      - (global) start column of Ed
.  cend        - (global) end column + 1 of Ed. In other words, E's column ownership is in the range [cstart, cend)
.  garray1[n1] - global column indices of Eo, where n1 is Eo's column size
.  ownerSF     - the SF specifying the ownership (root) of each row of E
.  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
-  mm          - to stash intermediate data structures for reuse

  Output Parameters:
+  map[n1]  - allocated by the caller. It maps garray1[] to garray2[]. See ReduceTwoSetsOfGlobalIndices() for details.
-  mm       - contains various info, such as garray2[], F (Fd, Fo), etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF and map are not significant.
*/
375 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
376 {
377   PetscFunctionBegin;
378   if (reuse == MAT_INITIAL_MATRIX) {
379     PetscInt Em = A.numRows(), Fm;
380     PetscInt n1 = B.numCols();
381 
382     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
383 
384     // Do the analysis on host
385     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
386     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
387     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
388     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
389     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
390     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
391 
392     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
393     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
394     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
395     for (PetscInt i = 0; i < Em; i++) {
396       const PetscInt *first, *last, *it;
397       PetscInt        count, step;
398       // std::lower_bound(first,last,cstart), but need to use global column indices
399       first = Bj + Bi[i];
400       last  = Bj + Bi[i + 1];
401       count = last - first;
402       while (count > 0) {
403         it   = first;
404         step = count / 2;
405         it += step;
406         if (garray1[*it] < cstart) { // map local to global
407           first = ++it;
408           count -= step + 1;
409         } else count = step;
410       }
411       E_NzLeft[i] = first - (Bj + Bi[i]);
412       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
413     }
414 
415     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
416     const PetscMPIInt *iranks, *ranks;
417     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
418     PetscInt           niranks, nranks;
419     MPI_Request       *reqs;
420     PetscMPIInt        tag;
421     PetscSF            reduceSF;
422     PetscInt          *sdisp, *rdisp;
423 
424     PetscCall(PetscCommGetNewTag(comm, &tag));
425     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
426     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
427 
428     // Find out length of each row I will receive. Even for the same row index, when they are from
429     // different senders, they might have different lengths (and sparsity patterns)
430     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
431     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
432 
433     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
434 
435     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
436     recvRowLen[0] = 0; // since we will make it in CSR format later
437     recvRowLen++;      // advance the pointer now
438     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
439     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
440     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
441 
442     // Build the real PetscSF for reducing E rows (buffer to buffer)
443     rdisp[0] = 0;
444     for (PetscInt i = 0; i < niranks; i++) {
445       rdisp[i + 1] = rdisp[i];
446       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
447     }
448     recvRowLen--; // put it back into csr format
449     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
450 
451     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
452     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
453     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
454 
455     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
456     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
457     PetscSFNode *iremote;
458 
459     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
460     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
461     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
462 
463     for (PetscInt i = 0; i < nranks; i++) {
464       PetscInt count = 0;
465       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
466       for (PetscInt j = 0; j < count; j++) {
467         iremote[nleaves + j].rank  = ranks[i];
468         iremote[nleaves + j].index = sdisp[i] + j;
469       }
470       nleaves += count;
471     }
472     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
473 
474     PetscCall(PetscSFCreate(comm, &reduceSF));
475     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
476 
477     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
478     PetscInt *sendCol, *recvCol;
479     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
480     for (PetscInt k = 0; k < roffset[nranks]; k++) {
481       PetscInt  i      = rmine[k]; // row to be copied
482       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
483       PetscInt  nzLeft = E_NzLeft[i];
484       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
485       for (PetscInt j = 0; j < alen + blen; j++) {
486         if (j < nzLeft) {
487           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
488         } else if (j < nzLeft + alen) {
489           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
490         } else {
491           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
492         }
493       }
494     }
495     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
496     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
497 
498     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
499     PetscInt *recvRowPerm, *recvColSorted;
500     PetscInt *recvNzPerm, *recvNzPermSorted;
501     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
502 
503     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
504     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
505     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
506 
507     // i[] array, nz are always easiest to compute
508     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
509     MatRowMapType          *Fdi, *Foi;
510     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
511     PetscInt                iter;
512 
513     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
514     Kokkos::deep_copy(Foi_h, 0);
515     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
516     Foi  = Foi_h.data() + 1;
517     iter = 0;
518     while (iter < recvRowCnt) { // iter over received rows
519       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
520       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
521 
522       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
523 
524       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
525       PetscInt  nz    = 0; // nz (with dups) in the current row
526       PetscInt *jbuf  = recvColSorted + FnzDups;
527       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
528       PetscInt *jbuf2 = jbuf; // temp pointers
529       PetscInt *pbuf2 = pbuf;
530       for (PetscInt d = 0; d < dupRows; d++) {
531         PetscInt i   = recvRowPerm[iter + d];
532         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
533         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
534         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
535         jbuf2 += len;
536         pbuf2 += len;
537         nz += len;
538       }
539       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
540 
541       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
542       PetscInt cur = 0;
543       while (cur < nz) {
544         PetscInt curColIdx = jbuf[cur];
545         PetscInt dups      = 1;
546 
547         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
548         if (curColIdx >= cstart && curColIdx < cend) {
549           Fdi[curRowIdx]++;
550           FdnzDups += dups;
551         } else {
552           Foi[curRowIdx]++;
553           FonzDups += dups;
554         }
555         cur += dups;
556       }
557 
558       FnzDups += nz;
559       iter += dupRows; // Move to next unique row
560     }
561 
562     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
563     Foi = Foi_h.data();
564     for (PetscInt i = 0; i < Fm; i++) {
565       Fdi[i + 1] += Fdi[i];
566       Foi[i + 1] += Foi[i];
567     }
568     Fdnz = Fdi[Fm];
569     Fonz = Foi[Fm];
570     PetscCall(PetscFree2(sendCol, recvCol));
571 
572     // Allocate j, jmap, jperm for Fd and Fo
573     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
574     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
575     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
576     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
577     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
578     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
579 
580     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
581     Fdjmap[0] = 0;
582     Fojmap[0] = 0;
583     FnzDups   = 0;
584     Fdnz      = 0;
585     Fonz      = 0;
586     iter      = 0; // iter over received rows
587     while (iter < recvRowCnt) {
588       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
589       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
590       PetscInt nz        = 0;                           // nz (with dups) in the current row
591 
592       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
593       for (PetscInt d = 0; d < dupRows; d++) {
594         PetscInt i = recvRowPerm[iter + d];
595         nz += recvRowLen[i + 1] - recvRowLen[i];
596       }
597 
598       PetscInt *jbuf = recvColSorted + FnzDups;
599       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
600       PetscInt cur = 0;
601       while (cur < nz) {
602         PetscInt curColIdx = jbuf[cur];
603         PetscInt dups      = 1;
604 
605         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
606         if (curColIdx >= cstart && curColIdx < cend) {
607           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
608           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
609           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
610           FdnzDups += dups;
611           Fdnz++;
612         } else {
613           Foj[Fonz]        = curColIdx; // in global
614           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
615           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
616           FonzDups += dups;
617           Fonz++;
618         }
619         cur += dups;
620         FnzDups += dups;
621       }
622       iter += dupRows; // Move to next unique row
623     }
624     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
625     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
626 
627     // Combine global column indices in garray1[] and Foj[]
628     PetscInt n2, *garray2;
629 
630     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
631     mm->sf       = reduceSF;
632     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
633     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
634     mm->garray   = garray2; // give ownership, so no free
635     mm->n        = n2;
636     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
637     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
638     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
639     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
640     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
641 
642     // Output Fd and Fo in KokkosCsrMatrix format
643     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
644     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
645     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
646     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
647     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
648     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
649 
650     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
651     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
652 
653     // Compute kernel launch parameters in merging E
654     PetscInt teamSize, vectorLength, rowsPerTeam;
655 
656     teamSize = vectorLength = rowsPerTeam = -1;
657     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
658     mm->E_TeamSize     = teamSize;
659     mm->E_VectorLength = vectorLength;
660     mm->E_RowsPerTeam  = rowsPerTeam;
661   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
662 
663   // Handy aliases
664   auto       &Aa           = A.values;
665   auto       &Ba           = B.values;
666   const auto &Ai           = A.graph.row_map;
667   const auto &Bi           = B.graph.row_map;
668   const auto &E_NzLeft     = mm->E_NzLeft;
669   auto       &leafBuf      = mm->leafBuf;
670   auto       &rootBuf      = mm->rootBuf;
671   PetscSF     reduceSF     = mm->sf;
672   PetscInt    Em           = A.numRows();
673   PetscInt    teamSize     = mm->E_TeamSize;
674   PetscInt    vectorLength = mm->E_VectorLength;
675   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
676   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
677 
678   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
679   PetscCallCXX(Kokkos::parallel_for(
680     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
681       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
682         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
683         if (i < Em) {
684           PetscInt disp   = Ai(i) + Bi(i);
685           PetscInt alen   = Ai(i + 1) - Ai(i);
686           PetscInt blen   = Bi(i + 1) - Bi(i);
687           PetscInt nzleft = E_NzLeft(i);
688 
689           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
690             MatScalar &val = leafBuf(disp + j);
691             if (j < nzleft) { // B left
692               val = Ba(Bi(i) + j);
693             } else if (j < nzleft + alen) { // diag A
694               val = Aa(Ai(i) + j - nzleft);
695             } else { // B right
696               val = Ba(Bi(i) + j - alen);
697             }
698           });
699         }
700       });
701     }));
702   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
703   PetscFunctionReturn(PETSC_SUCCESS);
704 }
705 
706 // To finish MatMPIAIJKokkosReduce.
707 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
708 {
709   auto       &leafBuf  = mm->leafBuf;
710   auto       &rootBuf  = mm->rootBuf;
711   auto       &Fda      = mm->Fd.values;
712   const auto &Fdjmap   = mm->Fdjmap;
713   const auto &Fdjperm  = mm->Fdjperm;
714   auto        Fdnz     = mm->Fd.nnz();
715   auto       &Foa      = mm->Fo.values;
716   const auto &Fojmap   = mm->Fojmap;
717   const auto &Fojperm  = mm->Fojperm;
718   auto        Fonz     = mm->Fo.nnz();
719   PetscSF     reduceSF = mm->sf;
720 
721   PetscFunctionBegin;
722   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
723 
724   // Reduce data in rootBuf to Fd and Fo
725   PetscCallCXX(Kokkos::parallel_for(
726     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
727       PetscScalar sum = 0.0;
728       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
729       Fda(i) = sum;
730     }));
731 
732   PetscCallCXX(Kokkos::parallel_for(
733     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
734       PetscScalar sum = 0.0;
735       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
736       Foa(i) = sum;
737     }));
738   PetscFunctionReturn(PETSC_SUCCESS);
739 }
740 
741 /*
742   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
743 
  This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but it
  runs on the device and involves various index mappings.
746 
747   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
748   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
749   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
750   F has the same column layout as E.
751 
  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
753   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
754   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
755   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
756   column indices in Fo and update Fo with local indices.
757 
758    Input Parameters:
759 +   E       - the MPIAIJKOKKOS matrix
760 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
761 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
762 -   mm      - to stash matproduct intermediate data structures
763 
764     Output Parameters:
765 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
766 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
767 
768     Notes:
769     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
    The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
771 */
772 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
773 {
774   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
775   Mat               A = empi->A, B = empi->B; // diag and off-diag
776   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
777   PetscInt          Em = E->rmap->n; // #local rows
778   MPI_Comm          comm;
779 
780   PetscFunctionBegin;
781   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
782   if (reuse == MAT_INITIAL_MATRIX) {
783     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
784     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
785     const PetscInt *garray1 = empi->garray; // its size is n1
786     PetscInt        cstart, cend;
787     PetscSF         bcastSF;
788 
789     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
790 
791     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
792     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
793     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
794     for (PetscInt i = 0; i < Em; i++) {
795       const PetscInt *first, *last, *it;
796       PetscInt        count, step;
797       // std::lower_bound(first,last,cstart), but need to use global column indices
798       first = Bj + Bi[i];
799       last  = Bj + Bi[i + 1];
800       count = last - first;
801       while (count > 0) {
802         it   = first;
803         step = count / 2;
804         it += step;
805         if (empi->garray[*it] < cstart) { // map local to global
806           first = ++it;
807           count -= step + 1;
808         } else count = step;
809       }
810       E_NzLeft[i] = first - (Bj + Bi[i]);
811       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
812     }
813 
814     // Compute row pointer Fi of F
815     PetscInt *Fi, Fm, Fnz;
816     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
817     PetscCall(PetscMalloc1(Fm + 1, &Fi));
818     Fi[0] = 0;
819     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
820     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
821     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
822     Fnz = Fi[Fm];
823 
824     // Build the real PetscSF for bcasting E rows (buffer to buffer)
825     const PetscMPIInt *iranks, *ranks;
826     const PetscInt    *ioffset, *irootloc, *roffset;
827     PetscInt           niranks, nranks, *sdisp, *rdisp;
828     MPI_Request       *reqs;
829     PetscMPIInt        tag;
830 
831     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
832     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
833     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
834 
835     sdisp[0] = 0; // send displacement
836     for (PetscInt i = 0; i < niranks; i++) {
837       sdisp[i + 1] = sdisp[i];
838       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
839         PetscInt r = irootloc[j]; // row to be sent
840         sdisp[i + 1] += E_RowLen[r];
841       }
842     }
843 
844     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
845     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
846     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
847     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
848 
849     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
850     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
851     PetscSFNode *iremote;                  // give ownership to bcastSF
852     PetscCall(PetscMalloc1(nleaves, &iremote));
853     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
854       PetscInt k = 0;
855       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
856         iremote[j].rank  = ranks[i];
857         iremote[j].index = rdisp[i] + k; // their root location
858         k++;
859       }
860     }
861     PetscCall(PetscSFCreate(comm, &bcastSF));
862     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
863     PetscCall(PetscFree3(sdisp, rdisp, reqs));
864 
865     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
866     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
867     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
868     rowoffset[0]                     = 0;
869     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
870 
871     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
872     PetscInt *jbuf, *Fj;
873     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
874     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
875       PetscInt  i      = irootloc[k]; // row to be copied
876       PetscInt *buf    = &jbuf[rowoffset[k]];
877       PetscInt  nzLeft = E_NzLeft[i];
878       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
879       for (PetscInt j = 0; j < alen + blen; j++) {
880         if (j < nzLeft) {
881           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
882         } else if (j < nzLeft + alen) {
883           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
884         } else {
885           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
886         }
887       }
888     }
889     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
890     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
891 
892     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
893     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
894     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
895     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
896     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
897 
898     Fdi[0] = Foi[0] = 0;
899     for (PetscInt i = 0; i < Fm; i++) {
900       PetscInt *first, *last, *lb1, *lb2;
901       // cut the row into: Left, [cstart, cend), Right
902       first       = Fj + Fi[i];
903       last        = Fj + Fi[i + 1];
904       lb1         = std::lower_bound(first, last, cstart);
905       F_NzLeft[i] = lb1 - first;
906       lb2         = std::lower_bound(first, last, cend);
907       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
908       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
909     }
910     for (PetscInt i = 0; i < Fm; i++) {
911       Fdi[i + 1] += Fdi[i];
912       Foi[i + 1] += Foi[i];
913     }
914 
915     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
916     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
917     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
918     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
919 
920     for (PetscInt i = 0; i < Fm; i++) {
921       PetscInt nzLeft = F_NzLeft[i];
922       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
923       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
924         gid = Fj[Fi[i] + j];
925         if (j < nzLeft) { // left, in global
926           Foj[Foi[i] + j] = gid;
927         } else if (j < nzLeft + len) { // diag, in local
928           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
929         } else { // right, in global
930           Foj[Foi[i] + j - len] = gid;
931         }
932       }
933     }
934     PetscCall(PetscFree2(jbuf, Fj));
935     PetscCall(PetscFree(Fi));
936 
937     // Reduce global indices in Foj[] and garray1[] into local ones
938     PetscInt n2, *garray2;
939     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
940 
941     // Record the plans built above, for reuse
942     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
943     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
944     Kokkos::deep_copy(irootloc_h, tmp);
945     mm->sf        = bcastSF;
946     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
947     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
948     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
949     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
950     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
951     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
952     mm->garray    = garray2;
953     mm->n         = n2;
954 
955     // Output Fd and Fo in KokkosCsrMatrix format
956     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
957     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
958     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
959     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
960     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
961 
962     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
963     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
964 
965     // Compute kernel launch parameters in merging E or splitting F
966     PetscInt teamSize, vectorLength, rowsPerTeam;
967 
968     teamSize = vectorLength = rowsPerTeam = -1;
969     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
970     mm->E_TeamSize     = teamSize;
971     mm->E_VectorLength = vectorLength;
972     mm->E_RowsPerTeam  = rowsPerTeam;
973 
974     teamSize = vectorLength = rowsPerTeam = -1;
975     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
976     mm->F_TeamSize     = teamSize;
977     mm->F_VectorLength = vectorLength;
978     mm->F_RowsPerTeam  = rowsPerTeam;
979   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
980 
981   // Sync E's value to device
982   akok->a_dual.sync_device();
983   bkok->a_dual.sync_device();
984 
985   // Handy aliases
986   const auto &Aa = akok->a_dual.view_device();
987   const auto &Ba = bkok->a_dual.view_device();
988   const auto &Ai = akok->i_dual.view_device();
989   const auto &Bi = bkok->i_dual.view_device();
990 
991   // Fetch the plans
992   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
993   PetscSF             &bcastSF   = mm->sf;
994   MatScalarKokkosView &rootBuf   = mm->rootBuf;
995   MatScalarKokkosView &leafBuf   = mm->leafBuf;
996   PetscIntKokkosView  &irootloc  = mm->irootloc;
997   PetscIntKokkosView  &rowoffset = mm->rowoffset;
998 
999   PetscInt teamSize     = mm->E_TeamSize;
1000   PetscInt vectorLength = mm->E_VectorLength;
1001   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1002   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1003 
1004   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1005   PetscCallCXX(Kokkos::parallel_for(
1006     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1007       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1008         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1009         if (r < irootloc.extent(0)) {
1010           PetscInt i      = irootloc(r); // row i of E
1011           PetscInt disp   = rowoffset(r);
1012           PetscInt alen   = Ai(i + 1) - Ai(i);
1013           PetscInt blen   = Bi(i + 1) - Bi(i);
1014           PetscInt nzleft = E_NzLeft(i);
1015 
1016           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1017             if (j < nzleft) { // B left
1018               rootBuf(disp + j) = Ba(Bi(i) + j);
1019             } else if (j < nzleft + alen) { // diag A
1020               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1021             } else { // B right
1022               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1023             }
1024           });
1025         }
1026       });
1027     }));
1028   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1029   PetscFunctionReturn(PETSC_SUCCESS);
1030 }
1031 
1032 // To finish MatMPIAIJKokkosBcast.
1033 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1034 {
1035   PetscFunctionBegin;
1036   const auto &Fd  = mm->Fd;
1037   const auto &Fo  = mm->Fo;
1038   const auto &Fdi = Fd.graph.row_map;
1039   const auto &Foi = Fo.graph.row_map;
1040   auto       &Fda = Fd.values;
1041   auto       &Foa = Fo.values;
1042   auto        Fm  = Fd.numRows();
1043 
1044   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1045   PetscSF             &bcastSF      = mm->sf;
1046   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1047   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1048   PetscInt             teamSize     = mm->F_TeamSize;
1049   PetscInt             vectorLength = mm->F_VectorLength;
1050   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1051   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1052 
1053   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1054 
1055   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1056   PetscCallCXX(Kokkos::parallel_for(
1057     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1058       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1059         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1060         if (i < Fm) {
1061           PetscInt nzLeft = F_NzLeft(i);
1062           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1063           PetscInt blen   = Foi(i + 1) - Foi(i);
1064           PetscInt Fii    = Fdi(i) + Foi(i);
1065 
1066           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1067             PetscScalar val = leafBuf(Fii + j);
1068             if (j < nzLeft) { // left
1069               Foa(Foi(i) + j) = val;
1070             } else if (j < nzLeft + alen) { // diag
1071               Fda(Fdi(i) + j - nzLeft) = val;
1072             } else { // right
1073               Foa(Foi(i) + j - alen) = val;
1074             }
1075           });
1076         }
1077       });
1078     }));
1079   PetscFunctionReturn(PETSC_SUCCESS);
1080 }
1081 
1082 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1083 {
1084   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1085   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1086   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1087   PetscInt        cstart, cend;
1088   MPI_Comm        comm;
1089 
1090   PetscFunctionBegin;
1091   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1092   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1093   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1094   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1095   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1096   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1097   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1098 
1099   // TODO: add command line options to select spgemm algorithms
1100   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1101 
1102   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1103 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1104   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1105   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1106   #endif
1107 #endif
1108 
1109   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1110   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1111   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1112   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1113 
1114   // Aot * (B's diag + B's off-diag)
1115   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1116   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1117   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1118   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1119   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1120   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1121 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1122 
1123   PetscCallCXX(sort_crs_matrix(mm->C3));
1124   PetscCallCXX(sort_crs_matrix(mm->C4));
1125 #endif
1126 
1127   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1128   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1129   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1130   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1131 
1132   // Adt * (B's diag + B's off-diag)
1133   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1134   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1135   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1136   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1137 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1138   PetscCallCXX(sort_crs_matrix(mm->C1));
1139   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1140 #endif
1141 
1142   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1143 
1144   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1145   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1146   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1147   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1148   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1149 
1150   // C = (C1+Fd, C2+Fo)
1151   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1152   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1153   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1154   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1155   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1156   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1157   PetscFunctionReturn(PETSC_SUCCESS);
1158 }
1159 
1160 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1161 {
1162   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1163   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1164   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1165   MPI_Comm        comm;
1166 
1167   PetscFunctionBegin;
1168   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1169   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1170   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1171   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1172   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1173 
1174   // Aot * (B's diag + B's off-diag)
1175   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1176   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1177 
1178   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1179   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1180 
1181   // Adt * (B's diag + B's off-diag)
1182   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1183   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1184 
1185   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1186 
1187   // C = (C1+Fd, C2+Fo)
1188   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1189   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1190   PetscFunctionReturn(PETSC_SUCCESS);
1191 }
1192 
1193 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1194 
1195   Input Parameters:
1196 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1197 .  A        - an MPIAIJKOKKOS matrix
1198 .  B        - an MPIAIJKOKKOS matrix
1199 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1200 */
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  // Kokkos CSR views of the diagonal (ampi->A, bmpi->A) and off-diagonal (ampi->B, bmpi->B) blocks of A and B
  KokkosCsrMatrix Ad, Ao, Bd, Bo;

  PetscFunctionBegin;
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
  // Hence, when cuSPARSE is the TPL but CUDA is older than 11.4, force Kokkos-Kernels' native (KK) algorithm.
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One kernel handle per partial product C1..C4. They live in mm so the numeric phase
  // (MatProductNumeric_MPIAIJKokkos_AB) can reuse them.
  mm->kh1.create_spgemm_handle(spgemm_alg);
  mm->kh2.create_spgemm_handle(spgemm_alg);
  mm->kh3.create_spgemm_handle(spgemm_alg);
  mm->kh4.create_spgemm_handle(spgemm_alg);

  // Bcast B's rows to form F, and overlap the communication
  // map_h has one entry per column of B's off-diag block. It is handed to the Bcast routines (presumably they fill it)
  // and is used below to renumber the column indices of C2_mid.
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
  PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // A's diag * (B's diag + B's off-diag)
  // C1 = Ad*Bd; C2_mid = Ad*Bo, whose column indices still refer to Bo's columns
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
  // With Kokkos-Kernels < 4.0 the column indices of the products are explicitly sorted here
  // (the spadd handles created below are told their inputs are sorted).
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  // F (mm->Fd, mm->Fo) is needed from here on
  PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // A's off-diag * (F's diag + F's off-diag)
  // C3 = Ao*Fd, C4 = Ao*Fo; as above, numeric is called right after symbolic to populate the column indices
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  // newj(i) = map(oldj(i)) is computed on device after copying map_h there; the new column count is mm->n.
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);

  // C = (Cd, Co) = (C1+C3, C2+C4)
  // kh1/kh2 are reused for the additions; symbolic and numeric spadd are both run so Cd/Co hold values on return.
  mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
  mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1271 
1272 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1273 {
1274   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1275   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1276   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1277 
1278   PetscFunctionBegin;
1279   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1280   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1281   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1282   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1283 
1284   // Bcast B's rows to form F, and overlap the communication
1285   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1286 
1287   // A's diag * (B's diag + B's off-diag)
1288   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1289   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1290 
1291   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1292 
1293   // A's off-diag * (F's diag + F's off-diag)
1294   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1295   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1296 
1297   // C = (Cd, Co) = (C1+C3, C2+C4)
1298   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1299   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1300   PetscFunctionReturn(PETSC_SUCCESS);
1301 }
1302 
1303 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1304 {
1305   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1306   Mat_Product                 *product;
1307   MatProductData_MPIAIJKokkos *pdata;
1308   MatProductType               ptype;
1309   Mat                          A, B;
1310 
1311   PetscFunctionBegin;
1312   MatCheckProduct(C, 1); // make sure C is a product
1313   product = C->product;
1314   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1315   ptype   = product->type;
1316   A       = product->A;
1317   B       = product->B;
1318 
1319   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1320   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1321   // we still do numeric.
1322   if (pdata->reusesym) { // numeric reuses results from symbolic
1323     pdata->reusesym = PETSC_FALSE;
1324     PetscFunctionReturn(PETSC_SUCCESS);
1325   }
1326 
1327   if (ptype == MATPRODUCT_AB) {
1328     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1329   } else if (ptype == MATPRODUCT_AtB) {
1330     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1331   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1332     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1333     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1334   }
1335 
1336   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1337   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1338   PetscFunctionReturn(PETSC_SUCCESS);
1339 }
1340 
/*
  MatProductSymbolic_MPIAIJKokkos - Symbolic phase of C = A*B (MATPRODUCT_AB), A^T*B (MATPRODUCT_AtB) or
  B^T*A*B (MATPRODUCT_PtAP) for MPIAIJKokkos operands.

  It does the following:
  - sets C's sizes and type;
  - runs the product-specific symbolic helpers, whose intermediate data is kept in a MatProductData_MPIAIJKokkos
    attached to C->product;
  - wraps the resulting device CSR matrices mm->{Cd, Co} as C's diagonal/off-diagonal blocks;
  - installs MatProductNumeric_MPIAIJKokkos as C's numeric routine.

  NOTE(review): pdata and the MatMatStruct_* objects are allocated with new. If a PetscCall() fails before pdata is
  attached to C->product at the end, they do not appear to be released here.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *pdata;
  MatMatStruct                *mm = NULL;
  PetscInt                     m, n, M, N;
  Mat                          Cd, Co;
  MPI_Comm                     comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
  MatCheckProduct(C, 1);
  product = C->product;
  PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  // Local (m, n) and global (M, N) sizes of the result C; unsupported product types are an error
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  // C gets the same type as A
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  // When the product comes from the user API (product->api_user), the values produced in this symbolic phase are reused:
  // the next MatProductNumeric_MPIAIJKokkos() call sees reusesym, clears it and returns without recomputing.
  pdata           = new MatProductData_MPIAIJKokkos();
  pdata->reusesym = product->api_user;

  if (ptype == MATPRODUCT_AB) {
    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
    mm = pdata->mmAB = mmAB;
  } else if (ptype == MATPRODUCT_AtB) {
    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
    mm = pdata->mmAtB = mmAtB;
  } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
    Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z

    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
    // Wrap Z's device CSR blocks as sequential Kokkos matrices
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
    pdata->mmAB = mmAB;

    m = A->rmap->n; // Z's layout
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    // Assemble the parallel matrix Z = A*B from its split blocks so it can be the right operand of B^T*Z
    PetscCall(MatCreate(comm, &Z));
    PetscCall(MatSetSizes(Z, m, n, M, N));
    PetscCall(PetscLayoutSetUp(Z->rmap));
    PetscCall(PetscLayoutSetUp(Z->cmap));
    PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
    PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));

    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

    pdata->Z = Z; // give ownership to pdata
    mm = pdata->mmAtB = mmAtB;
  }

  // Build C's diagonal/off-diagonal blocks from the final result mm->{Cd, Co}
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
  PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));

  // Hand pdata to the product; it is released through MatProductDataDestroy_MPIAIJKokkos
  C->product->data       = pdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1436 
1437 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1438 {
1439   Mat_Product *product = mat->product;
1440   PetscBool    match   = PETSC_FALSE;
1441   PetscBool    usecpu  = PETSC_FALSE;
1442 
1443   PetscFunctionBegin;
1444   MatCheckProduct(mat, 1);
1445   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1446   if (match) { /* we can always fallback to the CPU if requested */
1447     switch (product->type) {
1448     case MATPRODUCT_AB:
1449       if (product->api_user) {
1450         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1451         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1452         PetscOptionsEnd();
1453       } else {
1454         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1455         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1456         PetscOptionsEnd();
1457       }
1458       break;
1459     case MATPRODUCT_AtB:
1460       if (product->api_user) {
1461         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1462         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1463         PetscOptionsEnd();
1464       } else {
1465         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1466         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1467         PetscOptionsEnd();
1468       }
1469       break;
1470     case MATPRODUCT_PtAP:
1471       if (product->api_user) {
1472         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1473         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1474         PetscOptionsEnd();
1475       } else {
1476         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1477         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1478         PetscOptionsEnd();
1479       }
1480       break;
1481     default:
1482       break;
1483     }
1484     match = (PetscBool)!usecpu;
1485   }
1486   if (match) {
1487     switch (product->type) {
1488     case MATPRODUCT_AB:
1489     case MATPRODUCT_AtB:
1490     case MATPRODUCT_PtAP:
1491       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1492       break;
1493     default:
1494       break;
1495     }
1496   }
1497   /* fallback to MPIAIJ ops */
1498   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1499   PetscFunctionReturn(PETSC_SUCCESS);
1500 }
1501 
// Mirror of MatCOOStruct_MPIAIJ on device
// It is built from the host struct queried in MatSetPreallocationCOO_MPIAIJKokkos() and read by MatSetValuesCOO_MPIAIJKokkos().
// - Counts are copied as-is.
// - Index arrays are deep-copied into DefaultMemorySpace views.
// - sendbuf/recvbuf are device views created without initialization and without copying the host contents.
struct MatCOOStruct_MPIAIJKokkos {
  // number of COO values, i.e. the length of v[] passed to MatSetValuesCOO()
  PetscCount           n;
  // star forest used to send entries owned by other ranks to their owners; this struct holds a reference to it
  PetscSF              sf;
  // number of values of A (diag block) / B (off-diag block) updated from local COO entries
  PetscCount           Annz, Bnnz;
  // number of values of A / B updated from received remote entries
  PetscCount           Annz2, Bnnz2;
  // Local entries: A's i-th value gets the sum of v[Aperm1[k]] for k in [Ajmap1[i], Ajmap1[i+1]); likewise for B
  PetscCountKokkosView Ajmap1, Aperm1;
  PetscCountKokkosView Bjmap1, Bperm1;
  // Received entries: A's value Aimap2[i] gets recvbuf[Aperm2[k]] added for k in [Ajmap2[i], Ajmap2[i+1]); likewise for B
  PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
  PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
  // Entries to send: sendbuf[i] = v[Cperm1[i]]
  PetscCountKokkosView Cperm1;
  // Device communication buffers, of lengths coo_h->sendlen and coo_h->recvlen
  MatScalarKokkosView  sendbuf, recvbuf;

  MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
    n(coo_h->n),
    sf(coo_h->sf),
    Annz(coo_h->Annz),
    Bnnz(coo_h->Bnnz),
    Annz2(coo_h->Annz2),
    Bnnz2(coo_h->Bnnz2),
    Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
    Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
    Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
    Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
    Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
    Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
    Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
    Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
    Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
    Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
    Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
    sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
    recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
  {
    // Take a reference on the shared star forest; it is released in the destructor
    PetscCallVoid(PetscObjectReference((PetscObject)sf));
  }

  ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
};
1541 
1542 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1543 {
1544   PetscFunctionBegin;
1545   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1546   PetscFunctionReturn(PETSC_SUCCESS);
1547 }
1548 
1549 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1550 {
1551   PetscContainer             container_h, container_d;
1552   MatCOOStruct_MPIAIJ       *coo_h;
1553   MatCOOStruct_MPIAIJKokkos *coo_d;
1554 
1555   PetscFunctionBegin;
1556   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1557   mat->preallocated = PETSC_TRUE;
1558   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1559   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1560   PetscCall(MatZeroEntries(mat));
1561 
1562   // Copy the COO struct to device
1563   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1564   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1565   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1566 
1567   // Put the COO struct in a container and then attach that to the matrix
1568   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1569   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1570   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1571   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1572   PetscCall(PetscContainerDestroy(&container_d));
1573   PetscFunctionReturn(PETSC_SUCCESS);
1574 }
1575 
1576 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1577 {
1578   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1579   Mat                        A = mpiaij->A, B = mpiaij->B;
1580   MatScalarKokkosView        Aa, Ba;
1581   MatScalarKokkosView        v1;
1582   PetscMemType               memtype;
1583   PetscContainer             container;
1584   MatCOOStruct_MPIAIJKokkos *coo;
1585 
1586   PetscFunctionBegin;
1587   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1588   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1589 
1590   const auto &n      = coo->n;
1591   const auto &Annz   = coo->Annz;
1592   const auto &Annz2  = coo->Annz2;
1593   const auto &Bnnz   = coo->Bnnz;
1594   const auto &Bnnz2  = coo->Bnnz2;
1595   const auto &vsend  = coo->sendbuf;
1596   const auto &v2     = coo->recvbuf;
1597   const auto &Ajmap1 = coo->Ajmap1;
1598   const auto &Ajmap2 = coo->Ajmap2;
1599   const auto &Aimap2 = coo->Aimap2;
1600   const auto &Bjmap1 = coo->Bjmap1;
1601   const auto &Bjmap2 = coo->Bjmap2;
1602   const auto &Bimap2 = coo->Bimap2;
1603   const auto &Aperm1 = coo->Aperm1;
1604   const auto &Aperm2 = coo->Aperm2;
1605   const auto &Bperm1 = coo->Bperm1;
1606   const auto &Bperm2 = coo->Bperm2;
1607   const auto &Cperm1 = coo->Cperm1;
1608 
1609   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1610   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1611     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
1612   } else {
1613     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1614   }
1615 
1616   if (imode == INSERT_VALUES) {
1617     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1618     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1619   } else {
1620     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1621     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1622   }
1623 
1624   PetscCall(PetscLogGpuTimeBegin());
1625   /* Pack entries to be sent to remote */
1626   Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1627 
1628   /* Send remote entries to their owner and overlap the communication with local computation */
1629   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1630   /* Add local entries to A and B in one kernel */
1631   Kokkos::parallel_for(
1632     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1633       PetscScalar sum = 0.0;
1634       if (i < Annz) {
1635         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1636         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1637       } else {
1638         i -= Annz;
1639         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1640         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1641       }
1642     });
1643   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1644 
1645   /* Add received remote entries to A and B in one kernel */
1646   Kokkos::parallel_for(
1647     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1648       if (i < Annz2) {
1649         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1650       } else {
1651         i -= Annz2;
1652         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1653       }
1654     });
1655   PetscCall(PetscLogGpuTimeEnd());
1656 
1657   if (imode == INSERT_VALUES) {
1658     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1659     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1660   } else {
1661     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1662     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1663   }
1664   PetscFunctionReturn(PETSC_SUCCESS);
1665 }
1666 
1667 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1668 {
1669   PetscFunctionBegin;
1670   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1671   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1672   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1673   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1674   PetscCall(MatDestroy_MPIAIJ(A));
1675   PetscFunctionReturn(PETSC_SUCCESS);
1676 }
1677 
1678 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1679 {
1680   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1681   PetscBool   congruent;
1682 
1683   PetscFunctionBegin;
1684   PetscCall(MatHasCongruentLayouts(A, &congruent));
1685   if (congruent) { // square matrix and the diagonals are solely in the diag block
1686     PetscCall(MatShift(mpiaij->A, a));
1687   } else { // too hard, use the general version
1688     PetscCall(MatShift_Basic(A, a));
1689   }
1690   PetscFunctionReturn(PETSC_SUCCESS);
1691 }
1692 
1693 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1694 {
1695   PetscFunctionBegin;
1696   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1697   B->ops->mult                  = MatMult_MPIAIJKokkos;
1698   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1699   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1700   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1701   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1702   B->ops->shift                 = MatShift_MPIAIJKokkos;
1703 
1704   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1705   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1706   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1707   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1708   PetscFunctionReturn(PETSC_SUCCESS);
1709 }
1710 
1711 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1712 {
1713   Mat         B;
1714   Mat_MPIAIJ *a;
1715 
1716   PetscFunctionBegin;
1717   if (reuse == MAT_INITIAL_MATRIX) {
1718     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1719   } else if (reuse == MAT_REUSE_MATRIX) {
1720     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1721   }
1722   B = *newmat;
1723 
1724   B->boundtocpu = PETSC_FALSE;
1725   PetscCall(PetscFree(B->defaultvectype));
1726   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1727   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1728 
1729   a = static_cast<Mat_MPIAIJ *>(A->data);
1730   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1731   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1732   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1733   PetscCall(MatSetOps_MPIAIJKokkos(B));
1734   PetscFunctionReturn(PETSC_SUCCESS);
1735 }
1736 
1737 /*MC
   MATAIJKOKKOS - "aijkokkos" - a matrix type to be used for CSR sparse matrices with Kokkos

   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types.
1741 
1742    Options Database Key:
1743 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1744 
1745   Level: beginner
1746 
1747 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1748 M*/
1749 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1750 {
1751   PetscFunctionBegin;
1752   PetscCall(PetscKokkosInitializeCheck());
1753   PetscCall(MatCreate_MPIAIJ(A));
1754   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1755   PetscFunctionReturn(PETSC_SUCCESS);
1756 }
1757 
1758 /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1762 
1763   Collective
1764 
1765   Input Parameters:
+ comm  - MPI communicator
. m     - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given)
           This value should be the same as the local size used in creating the
           y vector for the matrix-vector product y = Ax.
. n     - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given).
           This value should be the same as the local size used in creating the
           x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
. M     - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
. N     - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
1775 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1776            (same value is used for all local rows)
1777 . d_nnz - array containing the number of nonzeros in the various rows of the
1778            DIAGONAL portion of the local submatrix (possibly different for each row)
1779            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1780            The size of this array is equal to the number of local rows, i.e `m`.
1781            For matrices you plan to factor you must leave room for the diagonal entry and
1782            put in the entry even if it is zero.
1783 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1784            submatrix (same value is used for all local rows).
1785 - o_nnz - array containing the number of nonzeros in the various rows of the
1786            OFF-DIAGONAL portion of the local submatrix (possibly different for
1787            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1788            structure. The size of this array is equal to the number
1789            of local rows, i.e `m`.
1790 
1791   Output Parameter:
1792 . A - the matrix
1793 
1794   Level: intermediate
1795 
1796   Notes:
1797   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1798   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1799   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1800 
  The AIJ format, also called compressed row storage, is fully compatible with standard Fortran
1802   storage.  That is, the stored row and column indices can begin at
1803   either one (as in Fortran) or zero.
1804 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1807 @*/
1808 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1809 {
1810   PetscMPIInt size;
1811 
1812   PetscFunctionBegin;
1813   PetscCall(MatCreate(comm, A));
1814   PetscCall(MatSetSizes(*A, m, n, M, N));
1815   PetscCallMPI(MPI_Comm_size(comm, &size));
1816   if (size > 1) {
1817     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1818     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1819   } else {
1820     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1821     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1822   }
1823   PetscFunctionReturn(PETSC_SUCCESS);
1824 }
1825