xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision e91c04dfc8a52dee1965211bb1cc8e5bf775178f)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <petsc/private/kokkosimpl.hpp>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <KokkosSparse_spadd.hpp>
9 #include <KokkosSparse_spgemm.hpp>
10 
11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12 {
13   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19    */
20   if (mode == MAT_FINAL_ASSEMBLY) {
21     PetscScalarKokkosView v;
22 
23     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
24     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
26     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
28   }
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33 {
34   Mat_MPIAIJ *mpiaij;
35 
36   PetscFunctionBegin;
37   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
38   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
39   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
40   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
41   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
46 {
47   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
48   PetscInt    nt;
49 
50   PetscFunctionBegin;
51   PetscCall(VecGetLocalSize(xx, &nt));
52   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
53   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
54   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
55   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
56   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
57   PetscFunctionReturn(PETSC_SUCCESS);
58 }
59 
60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
61 {
62   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
63   PetscInt    nt;
64 
65   PetscFunctionBegin;
66   PetscCall(VecGetLocalSize(xx, &nt));
67   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
68   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
69   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
70   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
71   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
76 {
77   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
78   PetscInt    nt;
79 
80   PetscFunctionBegin;
81   PetscCall(VecGetLocalSize(xx, &nt));
82   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
83   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
84   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
85   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
86   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
91    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
92    C still uses local column ids. Their corresponding global column ids are returned in glob.
93 */
94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
95 {
96   Mat             Ad, Ao;
97   const PetscInt *cmap;
98 
99   PetscFunctionBegin;
100   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
101   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
102   if (glob) {
103     PetscInt cst, i, dn, on, *gidx;
104     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
105     PetscCall(MatGetLocalSize(Ao, NULL, &on));
106     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
107     PetscCall(PetscMalloc1(dn + on, &gidx));
108     for (i = 0; i < dn; i++) gidx[i] = cst + i;
109     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
110     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
111   }
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
115 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
116 struct MatMatStruct {
117   PetscInt            n, *garray;     // C's garray and its size.
118   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
119   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
120   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
121   PetscIntKokkosView  E_NzLeft;
122   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
123   MatScalarKokkosView rootBuf, leafBuf;
124   KokkosCsrMatrix     Fd, Fo; // F in split form
125 
126   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
127   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
128   KernelHandle kh3; // compute C3
129   KernelHandle kh4; // compute C4
130 
131   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
132   PetscInt E_VectorLength;
133   PetscInt E_RowsPerTeam;
134   PetscInt F_TeamSize;
135   PetscInt F_VectorLength;
136   PetscInt F_RowsPerTeam;
137 
138   ~MatMatStruct()
139   {
140     PetscFunctionBegin;
141     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
142     PetscFunctionReturnVoid();
143   }
144 };
145 
146 struct MatMatStruct_AB : public MatMatStruct {
147   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
148   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
149   PetscIntKokkosView rowoffset;
150 };
151 
152 struct MatMatStruct_AtB : public MatMatStruct {
153   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
154   MatColIdxKokkosView Fdjperm;
155   MatColIdxKokkosView Fojmap;
156   MatColIdxKokkosView Fojperm;
157 };
158 
159 struct MatProductData_MPIAIJKokkos {
160   MatMatStruct_AB  *mmAB     = nullptr;
161   MatMatStruct_AtB *mmAtB    = nullptr;
162   PetscBool         reusesym = PETSC_FALSE;
163   Mat               Z        = nullptr; // store Z=AB in computing BtAB
164 
165   ~MatProductData_MPIAIJKokkos()
166   {
167     delete mmAB;
168     delete mmAtB;
169     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
170   }
171 };
172 
173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
174 {
175   PetscFunctionBegin;
176   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
177   PetscFunctionReturn(PETSC_SUCCESS);
178 }
179 
180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
181    It is similar to MatCreateMPIAIJWithSplitArrays.
182 
183   Input Parameters:
184 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
185 .  A     - the diag matrix using local col ids
186 -  B     - the offdiag matrix using global col ids
187 
188   Output Parameter:
189 .  mat   - the updated MATMPIAIJKOKKOS matrix
190 */
191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
192 {
193   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
194   PetscInt    m, n, M, N, Am, An, Bm, Bn;
195 
196   PetscFunctionBegin;
197   PetscCall(MatGetSize(mat, &M, &N));
198   PetscCall(MatGetLocalSize(mat, &m, &n));
199   PetscCall(MatGetLocalSize(A, &Am, &An));
200   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
201 
202   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
203   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
204   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
205   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
206   mpiaij->A      = A;
207   mpiaij->B      = B;
208   mpiaij->garray = garray;
209 
210   mat->preallocated     = PETSC_TRUE;
211   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
212 
213   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
214   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
215   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
216     also gets mpiaij->B compacted, with its col ids and size reduced
217   */
218   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
219   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
220   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
221   PetscFunctionReturn(PETSC_SUCCESS);
222 }
223 
224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
226 template <class ExecutionSpace>
227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
228 {
229 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
230   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
231 #else
232   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
233 #endif
234   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
235 
236   PetscFunctionBegin;
237   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
238 
239   if (nnz_per_row < 1) nnz_per_row = 1;
240 
241   int max_vector_length = teamPolicy.vector_length_max();
242 
243   if (vector_length < 1) {
244     vector_length = 1;
245     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
246   }
247 
248   // Determine rows per thread
249   if (rows_per_thread < 1) {
250     if (is_gpu_exec_space) rows_per_thread = 1;
251     else {
252       if (nnz_per_row < 20 && nnz > 5000000) {
253         rows_per_thread = 256;
254       } else rows_per_thread = 64;
255     }
256   }
257 
258   if (team_size < 1) {
259     if (is_gpu_exec_space) {
260       team_size = 256 / vector_length;
261     } else {
262       team_size = 1;
263     }
264   }
265 
266   rows_per_team = rows_per_thread * team_size;
267 
268   if (rows_per_team < 0) {
269     PetscInt nnz_per_team = 4096;
270     PetscInt conc         = ExecutionSpace().concurrency();
271     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
272     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
273   }
274   PetscFunctionReturn(PETSC_SUCCESS);
275 }
276 
277 /*
278   Reduce two sets of global indices into local ones
279 
280   Input Parameters:
281 +  n1          - size of garray1[], the first set
282 .  garray1[n1] - a sorted global index array (without duplicates)
283 .  m           - size of indices[], the second set
284 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
285 
286   Output Parameters:
287 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
288 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
289 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
290 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
291 
292    Example, say
293     n1         = 5
294     garray1[5] = {1, 4, 7, 8, 10}
295     m          = 4
296     indices[4] = {2, 4, 8, 9}
297 
298    Combining them together, we have 7 global indices in garray2[]
299     n2         = 7
300     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
301 
302    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
303     map[5] = {0, 2, 3, 4, 6}
304 
305    On output, indices[] is updated with local indices
306     indices[4] = {1, 2, 4, 5}
307 */
308 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
309 {
310   PetscHMapI    g2l = nullptr;
311   PetscHashIter iter;
312   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
313   PetscInt      n2, *garray2;
314 
315   PetscFunctionBegin;
316   tot = 0;
317   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
318   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
319     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
320     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
321   }
322 
323   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
324     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
325     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
326   }
327 
328   // Pull out (unique) globals in the hash table and put them in garray2[]
329   n2 = tot;
330   PetscCall(PetscMalloc1(n2, &garray2));
331   tot = 0;
332   PetscHashIterBegin(g2l, iter);
333   while (!PetscHashIterAtEnd(g2l, iter)) {
334     PetscHashIterGetKey(g2l, iter, key);
335     PetscHashIterNext(g2l, iter);
336     garray2[tot++] = key;
337   }
338 
339   // Sort garray2[] and then map them to local indices starting from 0
340   PetscCall(PetscSortInt(n2, garray2));
341   PetscCall(PetscHMapIClear(g2l));
342   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
343 
344   // Rewrite indices[] with local indices
345   for (PetscInt i = 0; i < m; i++) {
346     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
347     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
348     indices[i] = val;
349   }
350   // Record the map that maps garray1[i] to garray2[map[i]]
351   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
352   PetscCall(PetscHMapIDestroy(&g2l));
353   *n2_      = n2;
354   *garray2_ = garray2;
355   PetscFunctionReturn(PETSC_SUCCESS);
356 }
357 
/*
  MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)

  It is, in some sense, the reverse of MatMPIAIJKokkosBcast(), but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.

  Think of each row of E as a leaf; the given ownerSF then specifies the roots of the leaves. A root may connect to multiple leaves.
  In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.

  Input Parameters:
+  comm        - MPI communicator of E
.  A           - diag block of E, using local column indices
.  B           - off-diag block of E, using local column indices
.  cstart      - (global) start column of Ed
.  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
.  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
.  ownerSF     - the SF that specifies ownership (root) of rows in E
.  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
-  mm          - to stash intermediate data structures for reuse

  Output Parameters:
+  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
-  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.

 */
385 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
386 {
387   PetscFunctionBegin;
388   if (reuse == MAT_INITIAL_MATRIX) {
389     PetscInt Em = A.numRows(), Fm;
390     PetscInt n1 = B.numCols();
391 
392     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
393 
394     // Do the analysis on host
395     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
396     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
397     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
398     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
399     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
400     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
401 
402     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
403     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
404     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
405     for (PetscInt i = 0; i < Em; i++) {
406       const PetscInt *first, *last, *it;
407       PetscInt        count, step;
408       // std::lower_bound(first,last,cstart), but need to use global column indices
409       first = Bj + Bi[i];
410       last  = Bj + Bi[i + 1];
411       count = last - first;
412       while (count > 0) {
413         it   = first;
414         step = count / 2;
415         it += step;
416         if (garray1[*it] < cstart) { // map local to global
417           first = ++it;
418           count -= step + 1;
419         } else count = step;
420       }
421       E_NzLeft[i] = first - (Bj + Bi[i]);
422       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
423     }
424 
425     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
426     const PetscMPIInt *iranks, *ranks;
427     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
428     PetscMPIInt        niranks, nranks;
429     MPI_Request       *reqs;
430     PetscMPIInt        tag;
431     PetscSF            reduceSF;
432     PetscInt          *sdisp, *rdisp;
433 
434     PetscCall(PetscCommGetNewTag(comm, &tag));
435     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
436     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
437 
438     // Find out length of each row I will receive. Even for the same row index, when they are from
439     // different senders, they might have different lengths (and sparsity patterns)
440     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
441     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
442 
443     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
444 
445     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
446     recvRowLen[0] = 0; // since we will make it in CSR format later
447     recvRowLen++;      // advance the pointer now
448     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
449     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
450     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
451 
452     // Build the real PetscSF for reducing E rows (buffer to buffer)
453     rdisp[0] = 0;
454     for (PetscInt i = 0; i < niranks; i++) {
455       rdisp[i + 1] = rdisp[i];
456       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
457     }
458     recvRowLen--; // put it back into csr format
459     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
460 
461     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
462     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
463     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
464 
465     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
466     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
467     PetscSFNode *iremote;
468 
469     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
470     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
471     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
472 
473     for (PetscInt i = 0; i < nranks; i++) {
474       PetscInt count = 0;
475       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
476       for (PetscInt j = 0; j < count; j++) {
477         iremote[nleaves + j].rank  = ranks[i];
478         iremote[nleaves + j].index = sdisp[i] + j;
479       }
480       nleaves += count;
481     }
482     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
483 
484     PetscCall(PetscSFCreate(comm, &reduceSF));
485     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
486 
487     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
488     PetscInt *sendCol, *recvCol;
489     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
490     for (PetscInt k = 0; k < roffset[nranks]; k++) {
491       PetscInt  i      = rmine[k]; // row to be copied
492       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
493       PetscInt  nzLeft = E_NzLeft[i];
494       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
495       for (PetscInt j = 0; j < alen + blen; j++) {
496         if (j < nzLeft) {
497           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
498         } else if (j < nzLeft + alen) {
499           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
500         } else {
501           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
502         }
503       }
504     }
505     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
506     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
507 
508     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
509     PetscInt *recvRowPerm, *recvColSorted;
510     PetscInt *recvNzPerm, *recvNzPermSorted;
511     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
512 
513     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
514     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
515     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
516 
517     // i[] array, nz are always easiest to compute
518     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
519     MatRowMapType          *Fdi, *Foi;
520     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
521     PetscInt                iter;
522 
523     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
524     Kokkos::deep_copy(Foi_h, 0);
525     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
526     Foi  = Foi_h.data() + 1;
527     iter = 0;
528     while (iter < recvRowCnt) { // iter over received rows
529       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
530       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
531 
532       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
533 
534       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
535       PetscInt  nz    = 0; // nz (with dups) in the current row
536       PetscInt *jbuf  = recvColSorted + FnzDups;
537       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
538       PetscInt *jbuf2 = jbuf; // temp pointers
539       PetscInt *pbuf2 = pbuf;
540       for (PetscInt d = 0; d < dupRows; d++) {
541         PetscInt i   = recvRowPerm[iter + d];
542         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
543         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
544         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
545         jbuf2 += len;
546         pbuf2 += len;
547         nz += len;
548       }
549       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
550 
551       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
552       PetscInt cur = 0;
553       while (cur < nz) {
554         PetscInt curColIdx = jbuf[cur];
555         PetscInt dups      = 1;
556 
557         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
558         if (curColIdx >= cstart && curColIdx < cend) {
559           Fdi[curRowIdx]++;
560           FdnzDups += dups;
561         } else {
562           Foi[curRowIdx]++;
563           FonzDups += dups;
564         }
565         cur += dups;
566       }
567 
568       FnzDups += nz;
569       iter += dupRows; // Move to next unique row
570     }
571 
572     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
573     Foi = Foi_h.data();
574     for (PetscInt i = 0; i < Fm; i++) {
575       Fdi[i + 1] += Fdi[i];
576       Foi[i + 1] += Foi[i];
577     }
578     Fdnz = Fdi[Fm];
579     Fonz = Foi[Fm];
580     PetscCall(PetscFree2(sendCol, recvCol));
581 
582     // Allocate j, jmap, jperm for Fd and Fo
583     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
584     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
585     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
586     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
587     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
588     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
589 
590     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
591     Fdjmap[0] = 0;
592     Fojmap[0] = 0;
593     FnzDups   = 0;
594     Fdnz      = 0;
595     Fonz      = 0;
596     iter      = 0; // iter over received rows
597     while (iter < recvRowCnt) {
598       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
599       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
600       PetscInt nz        = 0;                           // nz (with dups) in the current row
601 
602       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
603       for (PetscInt d = 0; d < dupRows; d++) {
604         PetscInt i = recvRowPerm[iter + d];
605         nz += recvRowLen[i + 1] - recvRowLen[i];
606       }
607 
608       PetscInt *jbuf = recvColSorted + FnzDups;
609       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
610       PetscInt cur = 0;
611       while (cur < nz) {
612         PetscInt curColIdx = jbuf[cur];
613         PetscInt dups      = 1;
614 
615         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
616         if (curColIdx >= cstart && curColIdx < cend) {
617           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
618           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
619           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
620           FdnzDups += dups;
621           Fdnz++;
622         } else {
623           Foj[Fonz]        = curColIdx; // in global
624           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
625           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
626           FonzDups += dups;
627           Fonz++;
628         }
629         cur += dups;
630         FnzDups += dups;
631       }
632       iter += dupRows; // Move to next unique row
633     }
634     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
635     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
636 
637     // Combine global column indices in garray1[] and Foj[]
638     PetscInt n2, *garray2;
639 
640     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
641     mm->sf       = reduceSF;
642     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
643     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
644     mm->garray   = garray2; // give ownership, so no free
645     mm->n        = n2;
646     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
647     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
648     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
649     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
650     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
651 
652     // Output Fd and Fo in KokkosCsrMatrix format
653     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
654     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
655     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
656     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
657     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
658     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
659 
660     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
661     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
662 
663     // Compute kernel launch parameters in merging E
664     PetscInt teamSize, vectorLength, rowsPerTeam;
665 
666     teamSize = vectorLength = rowsPerTeam = -1;
667     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
668     mm->E_TeamSize     = teamSize;
669     mm->E_VectorLength = vectorLength;
670     mm->E_RowsPerTeam  = rowsPerTeam;
671   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
672 
673   // Handy aliases
674   auto       &Aa           = A.values;
675   auto       &Ba           = B.values;
676   const auto &Ai           = A.graph.row_map;
677   const auto &Bi           = B.graph.row_map;
678   const auto &E_NzLeft     = mm->E_NzLeft;
679   auto       &leafBuf      = mm->leafBuf;
680   auto       &rootBuf      = mm->rootBuf;
681   PetscSF     reduceSF     = mm->sf;
682   PetscInt    Em           = A.numRows();
683   PetscInt    teamSize     = mm->E_TeamSize;
684   PetscInt    vectorLength = mm->E_VectorLength;
685   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
686   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
687 
688   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
689   PetscCallCXX(Kokkos::parallel_for(
690     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
691       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
692         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
693         if (i < Em) {
694           PetscInt disp   = Ai(i) + Bi(i);
695           PetscInt alen   = Ai(i + 1) - Ai(i);
696           PetscInt blen   = Bi(i + 1) - Bi(i);
697           PetscInt nzleft = E_NzLeft(i);
698 
699           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
700             MatScalar &val = leafBuf(disp + j);
701             if (j < nzleft) { // B left
702               val = Ba(Bi(i) + j);
703             } else if (j < nzleft + alen) { // diag A
704               val = Aa(Ai(i) + j - nzleft);
705             } else { // B right
706               val = Ba(Bi(i) + j - alen);
707             }
708           });
709         }
710       });
711     }));
712   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
713   PetscFunctionReturn(PETSC_SUCCESS);
714 }
715 
716 // To finish MatMPIAIJKokkosReduce: complete the SF reduce started by the Begin routine, then fill the values of Fd and Fo.
//
// Only data stashed in mm is read in this body (sf, leafBuf, rootBuf, Fd, Fo, Fdjmap/Fdjperm, Fojmap/Fojperm);
// comm, A, B, cstart, cend, garray1, ownerSF, reuse and map are not referenced here.
//
// For the i-th nonzero of Fd, entries k in [Fdjmap(i), Fdjmap(i+1)) of Fdjperm give the positions in rootBuf whose
// values are summed into Fd.values(i). Fo is handled the same way with Fojmap/Fojperm.
717 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
718 {
719   auto       &leafBuf  = mm->leafBuf;
720   auto       &rootBuf  = mm->rootBuf;
721   auto       &Fda      = mm->Fd.values;
722   const auto &Fdjmap   = mm->Fdjmap;
723   const auto &Fdjperm  = mm->Fdjperm;
724   auto        Fdnz     = mm->Fd.nnz();
725   auto       &Foa      = mm->Fo.values;
726   const auto &Fojmap   = mm->Fojmap;
727   const auto &Fojperm  = mm->Fojperm;
728   auto        Fonz     = mm->Fo.nnz();
729   PetscSF     reduceSF = mm->sf;
730 
731   PetscFunctionBegin;
732   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); // rootBuf now holds the values sent from leafBuf
733 
734   // Reduce data in rootBuf to Fd and Fo
  // One work item per nonzero of Fd: sum its (possibly duplicated) contributions found in rootBuf
735   PetscCallCXX(Kokkos::parallel_for(
736     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
737       PetscScalar sum = 0.0;
738       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
739       Fda(i) = sum;
740     }));
741 
  // Same accumulation for the nonzeros of Fo
742   PetscCallCXX(Kokkos::parallel_for(
743     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
744       PetscScalar sum = 0.0;
745       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
746       Foa(i) = sum;
747     }));
748   PetscFunctionReturn(PETSC_SUCCESS);
749 }
750 
751 /*
752   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
753 
754   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
755   device and involves various index mapping.
756 
757   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
758   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
759   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
760   F has the same column layout as E.
761 
762   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
763   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
764   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
765   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
766   column indices in Fo and update Fo with local indices.
767 
768    Input Parameters:
769 +   E       - the MPIAIJKOKKOS matrix
770 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
771 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
772 -   mm      - to stash matproduct intermediate data structures
773 
774     Output Parameters:
775 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
776 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
777 
778     Notes:
779     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
780     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
781 */
782 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
783 {
784   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
785   Mat               A = empi->A, B = empi->B; // diag and off-diag
786   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
787   PetscInt          Em = E->rmap->n; // #local rows
788   MPI_Comm          comm;
789 
790   PetscFunctionBegin;
791   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
792   if (reuse == MAT_INITIAL_MATRIX) {
793     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
794     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
795     const PetscInt *garray1 = empi->garray; // its size is n1
796     PetscInt        cstart, cend;
797     PetscSF         bcastSF;
798 
799     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
800 
801     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
802     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
803     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
804     for (PetscInt i = 0; i < Em; i++) {
805       const PetscInt *first, *last, *it;
806       PetscInt        count, step;
807       // std::lower_bound(first,last,cstart), but need to use global column indices
808       first = Bj + Bi[i];
809       last  = Bj + Bi[i + 1];
810       count = last - first;
811       while (count > 0) {
812         it   = first;
813         step = count / 2;
814         it += step;
815         if (empi->garray[*it] < cstart) { // map local to global
816           first = ++it;
817           count -= step + 1;
818         } else count = step;
819       }
820       E_NzLeft[i] = first - (Bj + Bi[i]);
821       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
822     }
823 
824     // Compute row pointer Fi of F
825     PetscInt *Fi, Fm, Fnz;
826     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
827     PetscCall(PetscMalloc1(Fm + 1, &Fi));
828     Fi[0] = 0;
829     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
830     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
831     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
832     Fnz = Fi[Fm];
833 
834     // Build the real PetscSF for bcasting E rows (buffer to buffer)
835     const PetscMPIInt *iranks, *ranks;
836     const PetscInt    *ioffset, *irootloc, *roffset;
837     PetscMPIInt        niranks, nranks;
838     PetscInt          *sdisp, *rdisp;
839     MPI_Request       *reqs;
840     PetscMPIInt        tag;
841 
842     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
843     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
844     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
845 
846     sdisp[0] = 0; // send displacement
847     for (PetscInt i = 0; i < niranks; i++) {
848       sdisp[i + 1] = sdisp[i];
849       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
850         PetscInt r = irootloc[j]; // row to be sent
851         sdisp[i + 1] += E_RowLen[r];
852       }
853     }
854 
855     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
856     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
857     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
858     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
859 
860     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
861     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
862     PetscSFNode *iremote;                  // give ownership to bcastSF
863     PetscCall(PetscMalloc1(nleaves, &iremote));
864     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
865       PetscInt k = 0;
866       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
867         iremote[j].rank  = ranks[i];
868         iremote[j].index = rdisp[i] + k; // their root location
869         k++;
870       }
871     }
872     PetscCall(PetscSFCreate(comm, &bcastSF));
873     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
874     PetscCall(PetscFree3(sdisp, rdisp, reqs));
875 
876     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
877     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
878     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
879     rowoffset[0]                     = 0;
880     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
881 
882     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
883     PetscInt *jbuf, *Fj;
884     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
885     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
886       PetscInt  i      = irootloc[k]; // row to be copied
887       PetscInt *buf    = &jbuf[rowoffset[k]];
888       PetscInt  nzLeft = E_NzLeft[i];
889       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
890       for (PetscInt j = 0; j < alen + blen; j++) {
891         if (j < nzLeft) {
892           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
893         } else if (j < nzLeft + alen) {
894           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
895         } else {
896           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
897         }
898       }
899     }
900     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
901     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
902 
903     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
904     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
905     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
906     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
907     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
908 
909     Fdi[0] = Foi[0] = 0;
910     for (PetscInt i = 0; i < Fm; i++) {
911       PetscInt *first, *last, *lb1, *lb2;
912       // cut the row into: Left, [cstart, cend), Right
913       first       = Fj + Fi[i];
914       last        = Fj + Fi[i + 1];
915       lb1         = std::lower_bound(first, last, cstart);
916       F_NzLeft[i] = lb1 - first;
917       lb2         = std::lower_bound(first, last, cend);
918       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
919       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
920     }
921     for (PetscInt i = 0; i < Fm; i++) {
922       Fdi[i + 1] += Fdi[i];
923       Foi[i + 1] += Foi[i];
924     }
925 
926     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
927     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
928     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
929     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
930 
931     for (PetscInt i = 0; i < Fm; i++) {
932       PetscInt nzLeft = F_NzLeft[i];
933       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
934       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
935         gid = Fj[Fi[i] + j];
936         if (j < nzLeft) { // left, in global
937           Foj[Foi[i] + j] = gid;
938         } else if (j < nzLeft + len) { // diag, in local
939           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
940         } else { // right, in global
941           Foj[Foi[i] + j - len] = gid;
942         }
943       }
944     }
945     PetscCall(PetscFree2(jbuf, Fj));
946     PetscCall(PetscFree(Fi));
947 
948     // Reduce global indices in Foj[] and garray1[] into local ones
949     PetscInt n2, *garray2;
950     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
951 
952     // Record the plans built above, for reuse
953     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
954     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
955     Kokkos::deep_copy(irootloc_h, tmp);
956     mm->sf        = bcastSF;
957     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
958     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
959     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
960     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
961     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
962     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
963     mm->garray    = garray2;
964     mm->n         = n2;
965 
966     // Output Fd and Fo in KokkosCsrMatrix format
967     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
968     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
969     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
970     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
971     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
972 
973     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
974     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
975 
976     // Compute kernel launch parameters in merging E or splitting F
977     PetscInt teamSize, vectorLength, rowsPerTeam;
978 
979     teamSize = vectorLength = rowsPerTeam = -1;
980     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
981     mm->E_TeamSize     = teamSize;
982     mm->E_VectorLength = vectorLength;
983     mm->E_RowsPerTeam  = rowsPerTeam;
984 
985     teamSize = vectorLength = rowsPerTeam = -1;
986     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
987     mm->F_TeamSize     = teamSize;
988     mm->F_VectorLength = vectorLength;
989     mm->F_RowsPerTeam  = rowsPerTeam;
990   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
991 
992   // Sync E's value to device
993   akok->a_dual.sync_device();
994   bkok->a_dual.sync_device();
995 
996   // Handy aliases
997   const auto &Aa = akok->a_dual.view_device();
998   const auto &Ba = bkok->a_dual.view_device();
999   const auto &Ai = akok->i_dual.view_device();
1000   const auto &Bi = bkok->i_dual.view_device();
1001 
1002   // Fetch the plans
1003   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1004   PetscSF             &bcastSF   = mm->sf;
1005   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1006   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1007   PetscIntKokkosView  &irootloc  = mm->irootloc;
1008   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1009 
1010   PetscInt teamSize     = mm->E_TeamSize;
1011   PetscInt vectorLength = mm->E_VectorLength;
1012   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1013   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1014 
1015   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1016   PetscCallCXX(Kokkos::parallel_for(
1017     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1018       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1019         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1020         if (r < irootloc.extent(0)) {
1021           PetscInt i      = irootloc(r); // row i of E
1022           PetscInt disp   = rowoffset(r);
1023           PetscInt alen   = Ai(i + 1) - Ai(i);
1024           PetscInt blen   = Bi(i + 1) - Bi(i);
1025           PetscInt nzleft = E_NzLeft(i);
1026 
1027           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1028             if (j < nzleft) { // B left
1029               rootBuf(disp + j) = Ba(Bi(i) + j);
1030             } else if (j < nzleft + alen) { // diag A
1031               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1032             } else { // B right
1033               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1034             }
1035           });
1036         }
1037       });
1038     }));
1039   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1040   PetscFunctionReturn(PETSC_SUCCESS);
1041 }
1042 
1043 // To finish MatMPIAIJKokkosBcast: complete the SF bcast started by the Begin routine, then split the received values into Fd and Fo.
//
// Only data stashed in mm is read in this body (sf, rootBuf, leafBuf, Fd, Fo, F_NzLeft and the F_* launch parameters);
// E, ownerSF, reuse and map are not referenced here.
//
// leafBuf is treated as the value array of F: row i starts at Fdi(i) + Foi(i) and is laid out as (left | diag | right).
// The first F_NzLeft(i) entries and the trailing entries go to Fo, the middle entries (row length of Fd) go to Fd.
1044 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1045 {
1046   PetscFunctionBegin;
1047   const auto &Fd  = mm->Fd;
1048   const auto &Fo  = mm->Fo;
1049   const auto &Fdi = Fd.graph.row_map;
1050   const auto &Foi = Fo.graph.row_map;
1051   auto       &Fda = Fd.values;
1052   auto       &Foa = Fo.values;
1053   auto        Fm  = Fd.numRows();
1054 
1055   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1056   PetscSF             &bcastSF      = mm->sf;
1057   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1058   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1059   PetscInt             teamSize     = mm->F_TeamSize;
1060   PetscInt             vectorLength = mm->F_VectorLength;
1061   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1062   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam; // ceil(Fm / rowsPerTeam) teams
1063 
1064   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1065 
1066   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1067   PetscCallCXX(Kokkos::parallel_for(
1068     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1069       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1070         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1071         if (i < Fm) {
1072           PetscInt nzLeft = F_NzLeft(i);
1073           PetscInt alen   = Fdi(i + 1) - Fdi(i); // row length in Fd
1074           PetscInt blen   = Foi(i + 1) - Foi(i); // row length in Fo
1075           PetscInt Fii    = Fdi(i) + Foi(i); // start of row i in leafBuf
1076 
1077           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1078             PetscScalar val = leafBuf(Fii + j);
1079             if (j < nzLeft) { // left
1080               Foa(Foi(i) + j) = val;
1081             } else if (j < nzLeft + alen) { // diag
1082               Fda(Fdi(i) + j - nzLeft) = val;
1083             } else { // right
1084               Foa(Foi(i) + j - alen) = val;
1085             }
1086           });
1087         }
1088       });
1089     }));
1090   PetscFunctionReturn(PETSC_SUCCESS);
1091 }
1092 
/* MatProductSymbolic_MPIAIJKokkos_AtB - AtB flavor of the symbolic product; builds all intermediate matrices and stashes them in mm

   Steps, as performed below:
     C3 = Aot * Bd, C4 = Aot * Bo      (Aot: transpose of A's off-diag block)
     start reducing rows of (C3, C4) over A's Mvctx into (Fd, Fo) in mm
     C1 = Adt * Bd, C2_mid = Adt * Bo  (Adt: transpose of A's diag block), computed while the reduce is in flight
     finish the reduce
     C2 = C2_mid with column indices translated through map[] and mm->n columns
     Cd = C1 + Fd, Co = C2 + Fo

   product is not referenced in this body. Ad and Ao are fetched but not otherwise used here.
*/
1093 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1094 {
1095   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1096   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1097   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1098   PetscInt        cstart, cend;
1099   MPI_Comm        comm;
1100 
1101   PetscFunctionBegin;
1102   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1103   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1104   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1105   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1106   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1107   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1108   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1109 
1110   // TODO: add command line options to select spgemm algorithms
1111   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1112 
1113   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1114 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1115   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1116   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1117   #endif
1118 #endif
1119 
  // One handle per product; the same handles are passed to spgemm_numeric() in the numeric phase
1120   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1121   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1122   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1123   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1124 
1125   // Aot * (B's diag + B's off-diag)
1126   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1127   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1128   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1129   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1130   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1131   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1132 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1133 
1134   PetscCallCXX(sort_crs_matrix(mm->C3));
1135   PetscCallCXX(sort_crs_matrix(mm->C4));
1136 #endif
1137 
1138   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1139   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); // one entry per column of B's off-diag block
1140   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1141   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1142 
1143   // Adt * (B's diag + B's off-diag)
1144   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1145   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1146   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1147   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1148 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1149   PetscCallCXX(sort_crs_matrix(mm->C1));
1150   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1151 #endif
1152 
1153   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1154 
1155   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1156   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1157   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1158   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1159   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1160 
1161   // C = (C1+Fd, C2+Fo)
1162   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1163   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1164   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1165   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1166   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1167   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1168   PetscFunctionReturn(PETSC_SUCCESS);
1169 }
1170 
/* MatProductNumeric_MPIAIJKokkos_AtB - AtB flavor of the numeric product; recomputes values using the structures stashed in mm

   Follows the same sequence as the symbolic routine but only calls the numeric kernels: spgemm_numeric() for C3, C4, C1, C2_mid,
   the split-phase reduce in MAT_REUSE_MATRIX mode (which fills Fd and Fo from C3 and C4), and spadd_numeric() for Cd and Co.
   mm->C2 is not rebuilt here; presumably it still shares its value array with C2_mid, as set up in the symbolic phase.

   product is not referenced in this body.
*/
1171 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1172 {
1173   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1174   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1175   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1176   MPI_Comm        comm;
1177 
1178   PetscFunctionBegin;
1179   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1180   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1181   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1182   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1183   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1184 
1185   // Aot * (B's diag + B's off-diag)
1186   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1187   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1188 
1189   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  // In MAT_REUSE_MATRIX mode the index-related arguments are not needed, hence the 0/NULL placeholders
1190   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1191 
1192   // Adt * (B's diag + B's off-diag)
1193   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1194   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1195 
1196   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1197 
1198   // C = (C1+Fd, C2+Fo)
1199   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1200   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1201   PetscFunctionReturn(PETSC_SUCCESS);
1202 }
1203 
1204 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1205 
1206   Input Parameters:
1207 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1208 .  A        - an MPIAIJKOKKOS matrix
1209 .  B        - an MPIAIJKOKKOS matrix
1210 -  mm       - a struct used to stash intermediate data when computing AB. Persists from symbolic to numeric operations.
1211 */
1212 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1213 {
1214   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1215   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1216   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1217 
1218   PetscFunctionBegin;
1219   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1220   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1221   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1222   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1223 
1224   // TODO: add command line options to select spgemm algorithms
1225   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1226 
1227   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1228 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1229   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1230   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1231   #endif
1232 #endif
1233 
1234   mm->kh1.create_spgemm_handle(spgemm_alg);
1235   mm->kh2.create_spgemm_handle(spgemm_alg);
1236   mm->kh3.create_spgemm_handle(spgemm_alg);
1237   mm->kh4.create_spgemm_handle(spgemm_alg);
1238 
1239   // Bcast B's rows to form F, and overlap the communication
1240   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1241   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1242 
1243   // A's diag * (B's diag + B's off-diag)
1244   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1245   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1246   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1247   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1248   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1249   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1250 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1251   PetscCallCXX(sort_crs_matrix(mm->C1));
1252   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1253 #endif
1254 
1255   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1256 
1257   // A's off-diag * (F's diag + F's off-diag)
1258   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1259   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1260   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1261   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1262 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1263   PetscCallCXX(sort_crs_matrix(mm->C3));
1264   PetscCallCXX(sort_crs_matrix(mm->C4));
1265 #endif
1266 
1267   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1268   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1269   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1270   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1271   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1272 
1273   // C = (Cd, Co) = (C1+C3, C2+C4)
1274   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1275   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1276   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1277   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1278   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1279   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1280   PetscFunctionReturn(PETSC_SUCCESS);
1281 }
1282 
1283 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1284 {
1285   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1286   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1287   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1288 
1289   PetscFunctionBegin;
1290   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1291   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1292   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1293   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1294 
1295   // Bcast B's rows to form F, and overlap the communication
1296   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1297 
1298   // A's diag * (B's diag + B's off-diag)
1299   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1300   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1301 
1302   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1303 
1304   // A's off-diag * (F's diag + F's off-diag)
1305   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1306   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1307 
1308   // C = (Cd, Co) = (C1+C3, C2+C4)
1309   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1310   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1311   PetscFunctionReturn(PETSC_SUCCESS);
1312 }
1313 
1314 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1315 {
1316   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1317   Mat_Product                 *product;
1318   MatProductData_MPIAIJKokkos *pdata;
1319   MatProductType               ptype;
1320   Mat                          A, B;
1321 
1322   PetscFunctionBegin;
1323   MatCheckProduct(C, 1); // make sure C is a product
1324   product = C->product;
1325   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1326   ptype   = product->type;
1327   A       = product->A;
1328   B       = product->B;
1329 
1330   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1331   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1332   // we still do numeric.
1333   if (pdata->reusesym) { // numeric reuses results from symbolic
1334     pdata->reusesym = PETSC_FALSE;
1335     PetscFunctionReturn(PETSC_SUCCESS);
1336   }
1337 
1338   if (ptype == MATPRODUCT_AB) {
1339     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1340   } else if (ptype == MATPRODUCT_AtB) {
1341     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1342   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1343     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1344     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1345   }
1346 
1347   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1348   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1349   PetscFunctionReturn(PETSC_SUCCESS);
1350 }
1351 
1352 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1353 {
1354   Mat                          A, B;
1355   Mat_Product                 *product;
1356   MatProductType               ptype;
1357   MatProductData_MPIAIJKokkos *pdata;
1358   MatMatStruct                *mm = NULL;
1359   PetscInt                     m, n, M, N;
1360   Mat                          Cd, Co;
1361   MPI_Comm                     comm;
1362 
1363   PetscFunctionBegin;
1364   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1365   MatCheckProduct(C, 1);
1366   product = C->product;
1367   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1368   ptype = product->type;
1369   A     = product->A;
1370   B     = product->B;
1371 
1372   switch (ptype) {
1373   case MATPRODUCT_AB:
1374     m = A->rmap->n;
1375     n = B->cmap->n;
1376     M = A->rmap->N;
1377     N = B->cmap->N;
1378     break;
1379   case MATPRODUCT_AtB:
1380     m = A->cmap->n;
1381     n = B->cmap->n;
1382     M = A->cmap->N;
1383     N = B->cmap->N;
1384     break;
1385   case MATPRODUCT_PtAP:
1386     m = B->cmap->n;
1387     n = B->cmap->n;
1388     M = B->cmap->N;
1389     N = B->cmap->N;
1390     break; /* BtAB */
1391   default:
1392     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1393   }
1394 
1395   PetscCall(MatSetSizes(C, m, n, M, N));
1396   PetscCall(PetscLayoutSetUp(C->rmap));
1397   PetscCall(PetscLayoutSetUp(C->cmap));
1398   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1399 
1400   pdata           = new MatProductData_MPIAIJKokkos();
1401   pdata->reusesym = product->api_user;
1402 
1403   if (ptype == MATPRODUCT_AB) {
1404     auto mmAB = new MatMatStruct_AB();
1405     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1406     mm = pdata->mmAB = mmAB;
1407   } else if (ptype == MATPRODUCT_AtB) {
1408     auto mmAtB = new MatMatStruct_AtB();
1409     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1410     mm = pdata->mmAtB = mmAtB;
1411   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1412     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1413 
1414     auto mmAB = new MatMatStruct_AB();
1415     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1416     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1417     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1418     pdata->mmAB = mmAB;
1419 
1420     m = A->rmap->n; // Z's layout
1421     n = B->cmap->n;
1422     M = A->rmap->N;
1423     N = B->cmap->N;
1424     PetscCall(MatCreate(comm, &Z));
1425     PetscCall(MatSetSizes(Z, m, n, M, N));
1426     PetscCall(PetscLayoutSetUp(Z->rmap));
1427     PetscCall(PetscLayoutSetUp(Z->cmap));
1428     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1429     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1430 
1431     auto mmAtB = new MatMatStruct_AtB();
1432     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1433 
1434     pdata->Z = Z; // give ownership to pdata
1435     mm = pdata->mmAtB = mmAtB;
1436   }
1437 
1438   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1439   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1440   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1441 
1442   C->product->data       = pdata;
1443   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1444   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1445   PetscFunctionReturn(PETSC_SUCCESS);
1446 }
1447 
1448 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1449 {
1450   Mat_Product *product = mat->product;
1451   PetscBool    match   = PETSC_FALSE;
1452   PetscBool    usecpu  = PETSC_FALSE;
1453 
1454   PetscFunctionBegin;
1455   MatCheckProduct(mat, 1);
1456   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1457   if (match) { /* we can always fallback to the CPU if requested */
1458     switch (product->type) {
1459     case MATPRODUCT_AB:
1460       if (product->api_user) {
1461         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1462         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1463         PetscOptionsEnd();
1464       } else {
1465         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1466         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1467         PetscOptionsEnd();
1468       }
1469       break;
1470     case MATPRODUCT_AtB:
1471       if (product->api_user) {
1472         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1473         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1474         PetscOptionsEnd();
1475       } else {
1476         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1477         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1478         PetscOptionsEnd();
1479       }
1480       break;
1481     case MATPRODUCT_PtAP:
1482       if (product->api_user) {
1483         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1484         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1485         PetscOptionsEnd();
1486       } else {
1487         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1488         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1489         PetscOptionsEnd();
1490       }
1491       break;
1492     default:
1493       break;
1494     }
1495     match = (PetscBool)!usecpu;
1496   }
1497   if (match) {
1498     switch (product->type) {
1499     case MATPRODUCT_AB:
1500     case MATPRODUCT_AtB:
1501     case MATPRODUCT_PtAP:
1502       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1503       break;
1504     default:
1505       break;
1506     }
1507   }
1508   /* fallback to MPIAIJ ops */
1509   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1510   PetscFunctionReturn(PETSC_SUCCESS);
1511 }
1512 
1513 // Mirror of MatCOOStruct_MPIAIJ on device
1514 struct MatCOOStruct_MPIAIJKokkos {
1515   PetscCount           n;
1516   PetscSF              sf;
1517   PetscCount           Annz, Bnnz;
1518   PetscCount           Annz2, Bnnz2;
1519   PetscCountKokkosView Ajmap1, Aperm1;
1520   PetscCountKokkosView Bjmap1, Bperm1;
1521   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1522   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1523   PetscCountKokkosView Cperm1;
1524   MatScalarKokkosView  sendbuf, recvbuf;
1525 
1526   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1527   {
1528     auto &exec = PetscGetKokkosExecutionSpace();
1529 
1530     n       = coo_h->n;
1531     sf      = coo_h->sf;
1532     Annz    = coo_h->Annz;
1533     Bnnz    = coo_h->Bnnz;
1534     Annz2   = coo_h->Annz2;
1535     Bnnz2   = coo_h->Bnnz2;
1536     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1537     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1538     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1539     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1540     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1541     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1542     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1543     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1544     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1545     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1546     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1547     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1548     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1549     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1550   }
1551 
1552   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1553 };
1554 
1555 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void **data)
1556 {
1557   PetscFunctionBegin;
1558   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(*data));
1559   PetscFunctionReturn(PETSC_SUCCESS);
1560 }
1561 
1562 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1563 {
1564   PetscContainer             container_h, container_d;
1565   MatCOOStruct_MPIAIJ       *coo_h;
1566   MatCOOStruct_MPIAIJKokkos *coo_d;
1567 
1568   PetscFunctionBegin;
1569   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1570   mat->preallocated = PETSC_TRUE;
1571   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1572   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1573   PetscCall(MatZeroEntries(mat));
1574 
1575   // Copy the COO struct to device
1576   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1577   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1578   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1579 
1580   // Put the COO struct in a container and then attach that to the matrix
1581   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1582   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1583   PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1584   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1585   PetscCall(PetscContainerDestroy(&container_d));
1586   PetscFunctionReturn(PETSC_SUCCESS);
1587 }
1588 
1589 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1590 {
1591   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1592   Mat                            A = mpiaij->A, B = mpiaij->B;
1593   MatScalarKokkosView            Aa, Ba;
1594   MatScalarKokkosView            v1;
1595   PetscMemType                   memtype;
1596   PetscContainer                 container;
1597   MatCOOStruct_MPIAIJKokkos     *coo;
1598   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
1599 
1600   PetscFunctionBegin;
1601   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1602   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1603 
1604   const auto &n      = coo->n;
1605   const auto &Annz   = coo->Annz;
1606   const auto &Annz2  = coo->Annz2;
1607   const auto &Bnnz   = coo->Bnnz;
1608   const auto &Bnnz2  = coo->Bnnz2;
1609   const auto &vsend  = coo->sendbuf;
1610   const auto &v2     = coo->recvbuf;
1611   const auto &Ajmap1 = coo->Ajmap1;
1612   const auto &Ajmap2 = coo->Ajmap2;
1613   const auto &Aimap2 = coo->Aimap2;
1614   const auto &Bjmap1 = coo->Bjmap1;
1615   const auto &Bjmap2 = coo->Bjmap2;
1616   const auto &Bimap2 = coo->Bimap2;
1617   const auto &Aperm1 = coo->Aperm1;
1618   const auto &Aperm2 = coo->Aperm2;
1619   const auto &Bperm1 = coo->Bperm1;
1620   const auto &Bperm2 = coo->Bperm2;
1621   const auto &Cperm1 = coo->Cperm1;
1622 
1623   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1624   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1625     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1626   } else {
1627     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1628   }
1629 
1630   if (imode == INSERT_VALUES) {
1631     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1632     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1633   } else {
1634     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1635     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1636   }
1637 
1638   PetscCall(PetscLogGpuTimeBegin());
1639   /* Pack entries to be sent to remote */
1640   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1641 
1642   /* Send remote entries to their owner and overlap the communication with local computation */
1643   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1644   /* Add local entries to A and B in one kernel */
1645   Kokkos::parallel_for(
1646     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1647       PetscScalar sum = 0.0;
1648       if (i < Annz) {
1649         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1650         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1651       } else {
1652         i -= Annz;
1653         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1654         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1655       }
1656     });
1657   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1658 
1659   /* Add received remote entries to A and B in one kernel */
1660   Kokkos::parallel_for(
1661     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1662       if (i < Annz2) {
1663         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1664       } else {
1665         i -= Annz2;
1666         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1667       }
1668     });
1669   PetscCall(PetscLogGpuTimeEnd());
1670 
1671   if (imode == INSERT_VALUES) {
1672     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1673     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1674   } else {
1675     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1676     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1677   }
1678   PetscFunctionReturn(PETSC_SUCCESS);
1679 }
1680 
1681 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1682 {
1683   PetscFunctionBegin;
1684   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1685   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1686   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1687   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1688 #if defined(PETSC_HAVE_HYPRE)
1689   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
1690 #endif
1691   PetscCall(MatDestroy_MPIAIJ(A));
1692   PetscFunctionReturn(PETSC_SUCCESS);
1693 }
1694 
1695 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1696 {
1697   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1698   PetscBool   congruent;
1699 
1700   PetscFunctionBegin;
1701   PetscCall(MatHasCongruentLayouts(A, &congruent));
1702   if (congruent) { // square matrix and the diagonals are solely in the diag block
1703     PetscCall(MatShift(mpiaij->A, a));
1704   } else { // too hard, use the general version
1705     PetscCall(MatShift_Basic(A, a));
1706   }
1707   PetscFunctionReturn(PETSC_SUCCESS);
1708 }
1709 
1710 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1711 {
1712   PetscFunctionBegin;
1713   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1714   B->ops->mult                  = MatMult_MPIAIJKokkos;
1715   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1716   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1717   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1718   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1719   B->ops->shift                 = MatShift_MPIAIJKokkos;
1720 
1721   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1722   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1723   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1724   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1725 #if defined(PETSC_HAVE_HYPRE)
1726   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
1727 #endif
1728   PetscFunctionReturn(PETSC_SUCCESS);
1729 }
1730 
1731 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1732 {
1733   Mat         B;
1734   Mat_MPIAIJ *a;
1735 
1736   PetscFunctionBegin;
1737   if (reuse == MAT_INITIAL_MATRIX) {
1738     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1739   } else if (reuse == MAT_REUSE_MATRIX) {
1740     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1741   }
1742   B = *newmat;
1743 
1744   B->boundtocpu = PETSC_FALSE;
1745   PetscCall(PetscFree(B->defaultvectype));
1746   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1747   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1748 
1749   a = static_cast<Mat_MPIAIJ *>(A->data);
1750   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1751   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1752   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1753   PetscCall(MatSetOps_MPIAIJKokkos(B));
1754   PetscFunctionReturn(PETSC_SUCCESS);
1755 }
1756 
1757 /*MC
1758    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1759 
1760    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1761 
1762    Options Database Key:
1763 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1764 
1765   Level: beginner
1766 
1767 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1768 M*/
1769 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1770 {
1771   PetscFunctionBegin;
1772   PetscCall(PetscKokkosInitializeCheck());
1773   PetscCall(MatCreate_MPIAIJ(A));
1774   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1775   PetscFunctionReturn(PETSC_SUCCESS);
1776 }
1777 
1778 /*@C
1779   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
1780   (the default parallel PETSc format).  This matrix will ultimately pushed down
1781   to Kokkos for calculations.
1782 
1783   Collective
1784 
1785   Input Parameters:
1786 + comm  - MPI communicator, set to `PETSC_COMM_SELF`
1787 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1788            This value should be the same as the local size used in creating the
1789            y vector for the matrix-vector product y = Ax.
1790 . n     - This value should be the same as the local size used in creating the
1791        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1792        calculated if N is given) For square matrices n is almost always `m`.
1793 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1794 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1795 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1796            (same value is used for all local rows)
1797 . d_nnz - array containing the number of nonzeros in the various rows of the
1798            DIAGONAL portion of the local submatrix (possibly different for each row)
1799            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1800            The size of this array is equal to the number of local rows, i.e `m`.
1801            For matrices you plan to factor you must leave room for the diagonal entry and
1802            put in the entry even if it is zero.
1803 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1804            submatrix (same value is used for all local rows).
1805 - o_nnz - array containing the number of nonzeros in the various rows of the
1806            OFF-DIAGONAL portion of the local submatrix (possibly different for
1807            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1808            structure. The size of this array is equal to the number
1809            of local rows, i.e `m`.
1810 
1811   Output Parameter:
1812 . A - the matrix
1813 
1814   Level: intermediate
1815 
1816   Notes:
1817   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1818   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1819   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1820 
1821   The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
1822   storage.  That is, the stored row and column indices can begin at
1823   either one (as in Fortran) or zero.
1824 
1825 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1826           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1827 @*/
1828 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1829 {
1830   PetscMPIInt size;
1831 
1832   PetscFunctionBegin;
1833   PetscCall(MatCreate(comm, A));
1834   PetscCall(MatSetSizes(*A, m, n, M, N));
1835   PetscCallMPI(MPI_Comm_size(comm, &size));
1836   if (size > 1) {
1837     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1838     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1839   } else {
1840     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1841     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1842   }
1843   PetscFunctionReturn(PETSC_SUCCESS);
1844 }
1845