xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision fd4e4c6a5cea325de5cddb1efe0b692eacacde5a)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <petsc/private/kokkosimpl.hpp>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <KokkosSparse_spadd.hpp>
9 #include <KokkosSparse_spgemm.hpp>
10 
11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12 {
13   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19    */
20   if (mode == MAT_FINAL_ASSEMBLY) {
21     PetscScalarKokkosView v;
22 
23     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
24     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
26     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
28   }
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33 {
34   Mat_MPIAIJ *mpiaij;
35 
36   PetscFunctionBegin;
37   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
38   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
39   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
40   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
41   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
46 {
47   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
48   PetscInt    nt;
49 
50   PetscFunctionBegin;
51   PetscCall(VecGetLocalSize(xx, &nt));
52   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
53   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
54   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
55   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
56   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
57   PetscFunctionReturn(PETSC_SUCCESS);
58 }
59 
60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
61 {
62   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
63   PetscInt    nt;
64 
65   PetscFunctionBegin;
66   PetscCall(VecGetLocalSize(xx, &nt));
67   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
68   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
69   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
70   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
71   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
76 {
77   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
78   PetscInt    nt;
79 
80   PetscFunctionBegin;
81   PetscCall(VecGetLocalSize(xx, &nt));
82   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
83   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
84   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
85   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
86   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
91    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
92    C still uses local column ids. Their corresponding global column ids are returned in glob.
93 */
94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
95 {
96   Mat             Ad, Ao;
97   const PetscInt *cmap;
98 
99   PetscFunctionBegin;
100   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
101   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
102   if (glob) {
103     PetscInt cst, i, dn, on, *gidx;
104     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
105     PetscCall(MatGetLocalSize(Ao, NULL, &on));
106     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
107     PetscCall(PetscMalloc1(dn + on, &gidx));
108     for (i = 0; i < dn; i++) gidx[i] = cst + i;
109     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
110     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
111   }
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
/* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
// Base struct with the data shared by those products.
// The destructor destroys sf but does NOT free garray (ownership of garray is handled elsewhere — NOTE(review): confirm).
// NOTE(review): the implicitly generated copy constructor/assignment would copy the sf pointer, and every copy's destructor
// would then call PetscSFDestroy() on the same sf; instances must not be copied.
struct MatMatStruct {
  PetscInt            n, *garray;     // C's garray and its size.
  KokkosCsrMatrix     Cd, Co;         // C in split form (diag and offdiag matrices, all in local column indices)
  KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
  KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
  PetscIntKokkosView  E_NzLeft;       // per row of E: number of nonzeros left of the diag column block [cstart, cend)
  PetscSF             sf = nullptr;   // SF to bcast or reduce matrices E to F
  MatScalarKokkosView rootBuf, leafBuf; // value buffers for sf's roots and leaves
  KokkosCsrMatrix     Fd, Fo;         // F in split form

  KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
  KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
  KernelHandle kh3; // compute C3
  KernelHandle kh4; // compute C4

  PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F (see MatMergeGetLaunchParameters())
  PetscInt E_VectorLength;
  PetscInt E_RowsPerTeam;
  PetscInt F_TeamSize;
  PetscInt F_VectorLength;
  PetscInt F_RowsPerTeam;

  // Destroys sf; aborts (rather than returning an error) if PetscSFDestroy() fails, since destructors cannot return codes
  ~MatMatStruct()
  {
    PetscFunctionBegin;
    PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
    PetscFunctionReturnVoid();
  }
};
145 
// Extra plans kept for the C=AB product (as the name suggests — NOTE(review): confirm which products use it)
struct MatMatStruct_AB : public MatMatStruct {
  PetscIntKokkosView F_NzLeft;  // plans to split F (in leafBuf) into Fd, Fo
  PetscIntKokkosView irootloc;  // plans to put E (i.e., Bd, Bo) into rootBuf
  PetscIntKokkosView rowoffset; // NOTE(review): presumably row offsets used together with irootloc to place E's rows in rootBuf; confirm
};
151 
// Extra plans kept for the C=A^tB product; filled by MatMPIAIJKokkosReduceBegin().
// For the k-th nonzero of Fd, the received nonzeros (entries of rootBuf) contributing to it are
// rootBuf[Fdjperm[Fdjmap[k]]], ..., rootBuf[Fdjperm[Fdjmap[k+1]-1]] (presumably summed); Fojmap/Fojperm do the same for Fo.
struct MatMatStruct_AtB : public MatMatStruct {
  MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
  MatColIdxKokkosView Fdjperm;
  MatColIdxKokkosView Fojmap;
  MatColIdxKokkosView Fojperm;
};
158 
159 struct MatProductData_MPIAIJKokkos {
160   MatMatStruct_AB  *mmAB     = nullptr;
161   MatMatStruct_AtB *mmAtB    = nullptr;
162   PetscBool         reusesym = PETSC_FALSE;
163   Mat               Z        = nullptr; // store Z=AB in computing BtAB
164 
165   ~MatProductData_MPIAIJKokkos()
166   {
167     delete mmAB;
168     delete mmAtB;
169     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
170   }
171 };
172 
173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
174 {
175   PetscFunctionBegin;
176   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
177   PetscFunctionReturn(PETSC_SUCCESS);
178 }
179 
180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
181    It is similar to MatCreateMPIAIJWithSplitArrays.
182 
183   Input Parameters:
184 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
185 .  A     - the diag matrix using local col ids
186 -  B     - the offdiag matrix using global col ids
187 
188   Output Parameter:
189 .  mat   - the updated MATMPIAIJKOKKOS matrix
190 */
191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
192 {
193   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
194   PetscInt    m, n, M, N, Am, An, Bm, Bn;
195 
196   PetscFunctionBegin;
197   PetscCall(MatGetSize(mat, &M, &N));
198   PetscCall(MatGetLocalSize(mat, &m, &n));
199   PetscCall(MatGetLocalSize(A, &Am, &An));
200   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
201 
202   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
203   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
204   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
205   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
206   mpiaij->A      = A;
207   mpiaij->B      = B;
208   mpiaij->garray = garray;
209 
210   mat->preallocated     = PETSC_TRUE;
211   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
212 
213   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
214   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
215   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
216     also gets mpiaij->B compacted, with its col ids and size reduced
217   */
218   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
219   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
220   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
221   PetscFunctionReturn(PETSC_SUCCESS);
222 }
223 
224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
226 template <class ExecutionSpace>
227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
228 {
229 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
230   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
231 #else
232   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
233 #endif
234   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
235 
236   PetscFunctionBegin;
237   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
238 
239   if (nnz_per_row < 1) nnz_per_row = 1;
240 
241   int max_vector_length = teamPolicy.vector_length_max();
242 
243   if (vector_length < 1) {
244     vector_length = 1;
245     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
246   }
247 
248   // Determine rows per thread
249   if (rows_per_thread < 1) {
250     if (is_gpu_exec_space) rows_per_thread = 1;
251     else {
252       if (nnz_per_row < 20 && nnz > 5000000) {
253         rows_per_thread = 256;
254       } else rows_per_thread = 64;
255     }
256   }
257 
258   if (team_size < 1) {
259     if (is_gpu_exec_space) {
260       team_size = 256 / vector_length;
261     } else {
262       team_size = 1;
263     }
264   }
265 
266   rows_per_team = rows_per_thread * team_size;
267 
268   if (rows_per_team < 0) {
269     PetscInt nnz_per_team = 4096;
270     PetscInt conc         = ExecutionSpace().concurrency();
271     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
272     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
273   }
274   PetscFunctionReturn(PETSC_SUCCESS);
275 }
276 
277 /*
278   Reduce two sets of global indices into local ones
279 
280   Input Parameters:
281 +  n1          - size of garray1[], the first set
282 .  garray1[n1] - a sorted global index array (without duplicates)
283 .  m           - size of indices[], the second set
284 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
285 
286   Output Parameters:
287 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
288 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
289 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
290 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
291 
292    Example, say
293     n1         = 5
294     garray1[5] = {1, 4, 7, 8, 10}
295     m          = 4
296     indices[4] = {2, 4, 8, 9}
297 
298    Combining them together, we have 7 global indices in garray2[]
299     n2         = 7
300     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
301 
302    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
303     map[5] = {0, 2, 3, 4, 6}
304 
305    On output, indices[] is updated with local indices
306     indices[4] = {1, 2, 4, 5}
307 */
308 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
309 {
310   PetscHMapI    g2l = nullptr;
311   PetscHashIter iter;
312   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
313   PetscInt      n2, *garray2;
314 
315   PetscFunctionBegin;
316   tot = 0;
317   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
318   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
319     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
320     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
321   }
322 
323   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
324     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
325     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
326   }
327 
328   // Pull out (unique) globals in the hash table and put them in garray2[]
329   n2 = tot;
330   PetscCall(PetscMalloc1(n2, &garray2));
331   tot = 0;
332   PetscHashIterBegin(g2l, iter);
333   while (!PetscHashIterAtEnd(g2l, iter)) {
334     PetscHashIterGetKey(g2l, iter, key);
335     PetscHashIterNext(g2l, iter);
336     garray2[tot++] = key;
337   }
338 
339   // Sort garray2[] and then map them to local indices starting from 0
340   PetscCall(PetscSortInt(n2, garray2));
341   PetscCall(PetscHMapIClear(g2l));
342   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
343 
344   // Rewrite indices[] with local indices
345   for (PetscInt i = 0; i < m; i++) {
346     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
347     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
348     indices[i] = val;
349   }
350   // Record the map that maps garray1[i] to garray2[map[i]]
351   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
352   PetscCall(PetscHMapIDestroy(&g2l));
353   *n2_      = n2;
354   *garray2_ = garray2;
355   PetscFunctionReturn(PETSC_SUCCESS);
356 }
357 
358 /*
359   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
360 
361   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
362 
363   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
364   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
365 
366   Input Parameters:
367 +  comm       - MPI communicator of E
368 .  A          - diag block of E, using local column indices
369 .  B          - off-diag block of E, using local column indices
370 .  cstart      - (global) start column of Ed
371 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
372 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
373 .  ownerSF     - the SF specifies ownership (root) of rows in E
374 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
375 -  mm          - to stash intermediate data structures for reuse
376 
377   Output Parameters:
378 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
379 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
380 
381   Notes:
382   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
383 
384  */
385 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
386 {
387   PetscFunctionBegin;
388   if (reuse == MAT_INITIAL_MATRIX) {
389     PetscInt Em = A.numRows(), Fm;
390     PetscInt n1 = B.numCols();
391 
392     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
393 
394     // Do the analysis on host
395     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map);
396     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries);
397     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map);
398     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries);
399     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
400     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
401 
402     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
403     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
404     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
405     for (PetscInt i = 0; i < Em; i++) {
406       const PetscInt *first, *last, *it;
407       PetscInt        count, step;
408       // std::lower_bound(first,last,cstart), but need to use global column indices
409       first = Bj + Bi[i];
410       last  = Bj + Bi[i + 1];
411       count = last - first;
412       while (count > 0) {
413         it   = first;
414         step = count / 2;
415         it += step;
416         if (garray1[*it] < cstart) { // map local to global
417           first = ++it;
418           count -= step + 1;
419         } else count = step;
420       }
421       E_NzLeft[i] = first - (Bj + Bi[i]);
422       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
423     }
424 
425     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
426     const PetscMPIInt *iranks, *ranks;
427     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
428     PetscMPIInt        niranks, nranks;
429     MPI_Request       *reqs;
430     PetscMPIInt        tag;
431     PetscSF            reduceSF;
432     PetscInt          *sdisp, *rdisp;
433 
434     PetscCall(PetscCommGetNewTag(comm, &tag));
435     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
436     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
437 
438     // Find out length of each row I will receive. Even for the same row index, when they are from
439     // different senders, they might have different lengths (and sparsity patterns)
440     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
441     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
442 
443     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
444 
445     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
446     recvRowLen[0] = 0; // since we will make it in CSR format later
447     recvRowLen++;      // advance the pointer now
448     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
449     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
450     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
451 
452     // Build the real PetscSF for reducing E rows (buffer to buffer)
453     rdisp[0] = 0;
454     for (PetscInt i = 0; i < niranks; i++) {
455       rdisp[i + 1] = rdisp[i];
456       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
457     }
458     recvRowLen--; // put it back into csr format
459     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
460 
461     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
462     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
463     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
464 
465     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
466     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
467     PetscSFNode *iremote;
468 
469     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
470     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
471     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
472 
473     for (PetscInt i = 0; i < nranks; i++) {
474       PetscInt count = 0;
475       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
476       for (PetscInt j = 0; j < count; j++) {
477         iremote[nleaves + j].rank  = ranks[i];
478         iremote[nleaves + j].index = sdisp[i] + j;
479       }
480       nleaves += count;
481     }
482     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
483 
484     PetscCall(PetscSFCreate(comm, &reduceSF));
485     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
486 
487     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
488     PetscInt *sendCol, *recvCol;
489     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
490     for (PetscInt k = 0; k < roffset[nranks]; k++) {
491       PetscInt  i      = rmine[k]; // row to be copied
492       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
493       PetscInt  nzLeft = E_NzLeft[i];
494       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
495       for (PetscInt j = 0; j < alen + blen; j++) {
496         if (j < nzLeft) {
497           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
498         } else if (j < nzLeft + alen) {
499           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
500         } else {
501           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
502         }
503       }
504     }
505     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
506     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
507 
508     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
509     PetscInt *recvRowPerm, *recvColSorted;
510     PetscInt *recvNzPerm, *recvNzPermSorted;
511     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
512 
513     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
514     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
515     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
516 
517     // i[] array, nz are always easiest to compute
518     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
519     MatRowMapType          *Fdi, *Foi;
520     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
521     PetscInt                iter;
522 
523     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
524     Kokkos::deep_copy(Foi_h, 0);
525     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
526     Foi  = Foi_h.data() + 1;
527     iter = 0;
528     while (iter < recvRowCnt) { // iter over received rows
529       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
530       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
531 
532       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
533 
534       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
535       PetscInt  nz    = 0; // nz (with dups) in the current row
536       PetscInt *jbuf  = recvColSorted + FnzDups;
537       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
538       PetscInt *jbuf2 = jbuf; // temp pointers
539       PetscInt *pbuf2 = pbuf;
540       for (PetscInt d = 0; d < dupRows; d++) {
541         PetscInt i   = recvRowPerm[iter + d];
542         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
543         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
544         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
545         jbuf2 += len;
546         pbuf2 += len;
547         nz += len;
548       }
549       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
550 
551       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
552       PetscInt cur = 0;
553       while (cur < nz) {
554         PetscInt curColIdx = jbuf[cur];
555         PetscInt dups      = 1;
556 
557         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
558         if (curColIdx >= cstart && curColIdx < cend) {
559           Fdi[curRowIdx]++;
560           FdnzDups += dups;
561         } else {
562           Foi[curRowIdx]++;
563           FonzDups += dups;
564         }
565         cur += dups;
566       }
567 
568       FnzDups += nz;
569       iter += dupRows; // Move to next unique row
570     }
571 
572     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
573     Foi = Foi_h.data();
574     for (PetscInt i = 0; i < Fm; i++) {
575       Fdi[i + 1] += Fdi[i];
576       Foi[i + 1] += Foi[i];
577     }
578     Fdnz = Fdi[Fm];
579     Fonz = Foi[Fm];
580     PetscCall(PetscFree2(sendCol, recvCol));
581 
582     // Allocate j, jmap, jperm for Fd and Fo
583     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
584     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
585     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
586     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
587     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
588     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
589 
590     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
591     Fdjmap[0] = 0;
592     Fojmap[0] = 0;
593     FnzDups   = 0;
594     Fdnz      = 0;
595     Fonz      = 0;
596     iter      = 0; // iter over received rows
597     while (iter < recvRowCnt) {
598       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
599       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
600       PetscInt nz        = 0;                           // nz (with dups) in the current row
601 
602       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
603       for (PetscInt d = 0; d < dupRows; d++) {
604         PetscInt i = recvRowPerm[iter + d];
605         nz += recvRowLen[i + 1] - recvRowLen[i];
606       }
607 
608       PetscInt *jbuf = recvColSorted + FnzDups;
609       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
610       PetscInt cur = 0;
611       while (cur < nz) {
612         PetscInt curColIdx = jbuf[cur];
613         PetscInt dups      = 1;
614 
615         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
616         if (curColIdx >= cstart && curColIdx < cend) {
617           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
618           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
619           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
620           FdnzDups += dups;
621           Fdnz++;
622         } else {
623           Foj[Fonz]        = curColIdx; // in global
624           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
625           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
626           FonzDups += dups;
627           Fonz++;
628         }
629         cur += dups;
630         FnzDups += dups;
631       }
632       iter += dupRows; // Move to next unique row
633     }
634     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
635     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
636 
637     // Combine global column indices in garray1[] and Foj[]
638     PetscInt n2, *garray2;
639 
640     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
641     mm->sf       = reduceSF;
642     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
643     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
644     mm->garray   = garray2; // give ownership, so no free
645     mm->n        = n2;
646     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
647     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
648     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
649     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
650     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
651 
652     // Output Fd and Fo in KokkosCsrMatrix format
653     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
654     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
655     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
656     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
657     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
658     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
659 
660     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
661     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
662 
663     // Compute kernel launch parameters in merging E
664     PetscInt teamSize, vectorLength, rowsPerTeam;
665 
666     teamSize = vectorLength = rowsPerTeam = -1;
667     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
668     mm->E_TeamSize     = teamSize;
669     mm->E_VectorLength = vectorLength;
670     mm->E_RowsPerTeam  = rowsPerTeam;
671   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
672 
673   // Handy aliases
674   auto       &Aa           = A.values;
675   auto       &Ba           = B.values;
676   const auto &Ai           = A.graph.row_map;
677   const auto &Bi           = B.graph.row_map;
678   const auto &E_NzLeft     = mm->E_NzLeft;
679   auto       &leafBuf      = mm->leafBuf;
680   auto       &rootBuf      = mm->rootBuf;
681   PetscSF     reduceSF     = mm->sf;
682   PetscInt    Em           = A.numRows();
683   PetscInt    teamSize     = mm->E_TeamSize;
684   PetscInt    vectorLength = mm->E_VectorLength;
685   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
686   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
687 
688   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
689   PetscCallCXX(Kokkos::parallel_for(
690     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
691       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
692         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
693         if (i < Em) {
694           PetscInt disp   = Ai(i) + Bi(i);
695           PetscInt alen   = Ai(i + 1) - Ai(i);
696           PetscInt blen   = Bi(i + 1) - Bi(i);
697           PetscInt nzleft = E_NzLeft(i);
698 
699           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
700             MatScalar &val = leafBuf(disp + j);
701             if (j < nzleft) { // B left
702               val = Ba(Bi(i) + j);
703             } else if (j < nzleft + alen) { // diag A
704               val = Aa(Ai(i) + j - nzleft);
705             } else { // B right
706               val = Ba(Bi(i) + j - alen);
707             }
708           });
709         }
710       });
711     }));
712   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
713   PetscFunctionReturn(PETSC_SUCCESS);
714 }
715 
716 // To finish MatMPIAIJKokkosReduce.
717 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
718 {
719   auto       &leafBuf  = mm->leafBuf;
720   auto       &rootBuf  = mm->rootBuf;
721   auto       &Fda      = mm->Fd.values;
722   const auto &Fdjmap   = mm->Fdjmap;
723   const auto &Fdjperm  = mm->Fdjperm;
724   auto        Fdnz     = mm->Fd.nnz();
725   auto       &Foa      = mm->Fo.values;
726   const auto &Fojmap   = mm->Fojmap;
727   const auto &Fojperm  = mm->Fojperm;
728   auto        Fonz     = mm->Fo.nnz();
729   PetscSF     reduceSF = mm->sf;
730 
731   PetscFunctionBegin;
732   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
733 
734   // Reduce data in rootBuf to Fd and Fo
735   PetscCallCXX(Kokkos::parallel_for(
736     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
737       PetscScalar sum = 0.0;
738       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
739       Fda(i) = sum;
740     }));
741 
742   PetscCallCXX(Kokkos::parallel_for(
743     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
744       PetscScalar sum = 0.0;
745       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
746       Foa(i) = sum;
747     }));
748   PetscFunctionReturn(PETSC_SUCCESS);
749 }
750 
751 /*
752   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
753 
754   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
755   device and involves various index mapping.
756 
757   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
758   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
759   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
760   F has the same column layout as E.
761 
762   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
763   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
764   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
765   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
766   column indices in Fo and update Fo with local indices.
767 
768    Input Parameters:
769 +   E       - the MPIAIJKOKKOS matrix
770 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
771 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
772 -   mm      - to stash matproduct intermediate data structures
773 
774     Output Parameters:
775 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
776 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
777 
778     Notes:
779     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
780     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
781 */
782 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
783 {
784   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
785   Mat               A = empi->A, B = empi->B; // diag and off-diag
786   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
787   PetscInt          Em = E->rmap->n; // #local rows
788   MPI_Comm          comm;
789 
790   PetscFunctionBegin;
791   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
792   if (reuse == MAT_INITIAL_MATRIX) {
793     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
794     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
795     const PetscInt *garray1 = empi->garray; // its size is n1
796     PetscInt        cstart, cend;
797     PetscSF         bcastSF;
798 
799     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
800 
801     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
802     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
803     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
804     for (PetscInt i = 0; i < Em; i++) {
805       const PetscInt *first, *last, *it;
806       PetscInt        count, step;
807       // std::lower_bound(first,last,cstart), but need to use global column indices
808       first = Bj + Bi[i];
809       last  = Bj + Bi[i + 1];
810       count = last - first;
811       while (count > 0) {
812         it   = first;
813         step = count / 2;
814         it += step;
815         if (empi->garray[*it] < cstart) { // map local to global
816           first = ++it;
817           count -= step + 1;
818         } else count = step;
819       }
820       E_NzLeft[i] = first - (Bj + Bi[i]);
821       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
822     }
823 
824     // Compute row pointer Fi of F
825     PetscInt *Fi, Fm, Fnz;
826     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
827     PetscCall(PetscMalloc1(Fm + 1, &Fi));
828     Fi[0] = 0;
829     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
830     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
831     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
832     Fnz = Fi[Fm];
833 
834     // Build the real PetscSF for bcasting E rows (buffer to buffer)
835     const PetscMPIInt *iranks, *ranks;
836     const PetscInt    *ioffset, *irootloc, *roffset;
837     PetscMPIInt        niranks, nranks;
838     PetscInt          *sdisp, *rdisp;
839     MPI_Request       *reqs;
840     PetscMPIInt        tag;
841 
842     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
843     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
844     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
845 
846     sdisp[0] = 0; // send displacement
847     for (PetscInt i = 0; i < niranks; i++) {
848       sdisp[i + 1] = sdisp[i];
849       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
850         PetscInt r = irootloc[j]; // row to be sent
851         sdisp[i + 1] += E_RowLen[r];
852       }
853     }
854 
855     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
856     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
857     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
858     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
859 
860     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
861     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
862     PetscSFNode *iremote;                  // give ownership to bcastSF
863     PetscCall(PetscMalloc1(nleaves, &iremote));
864     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
865       PetscInt k = 0;
866       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
867         iremote[j].rank  = ranks[i];
868         iremote[j].index = rdisp[i] + k; // their root location
869         k++;
870       }
871     }
872     PetscCall(PetscSFCreate(comm, &bcastSF));
873     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
874     PetscCall(PetscFree3(sdisp, rdisp, reqs));
875 
876     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
877     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
878     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
879     rowoffset[0]                     = 0;
880     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
881 
882     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
883     PetscInt *jbuf, *Fj;
884     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
885     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
886       PetscInt  i      = irootloc[k]; // row to be copied
887       PetscInt *buf    = &jbuf[rowoffset[k]];
888       PetscInt  nzLeft = E_NzLeft[i];
889       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
890       for (PetscInt j = 0; j < alen + blen; j++) {
891         if (j < nzLeft) {
892           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
893         } else if (j < nzLeft + alen) {
894           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
895         } else {
896           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
897         }
898       }
899     }
900     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
901     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
902 
903     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
904     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
905     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
906     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
907     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
908 
909     Fdi[0] = Foi[0] = 0;
910     for (PetscInt i = 0; i < Fm; i++) {
911       PetscInt *first, *last, *lb1, *lb2;
912       // cut the row into: Left, [cstart, cend), Right
913       first       = Fj + Fi[i];
914       last        = Fj + Fi[i + 1];
915       lb1         = std::lower_bound(first, last, cstart);
916       F_NzLeft[i] = lb1 - first;
917       lb2         = std::lower_bound(first, last, cend);
918       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
919       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
920     }
921     for (PetscInt i = 0; i < Fm; i++) {
922       Fdi[i + 1] += Fdi[i];
923       Foi[i + 1] += Foi[i];
924     }
925 
926     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
927     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
928     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
929     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
930 
931     for (PetscInt i = 0; i < Fm; i++) {
932       PetscInt nzLeft = F_NzLeft[i];
933       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
934       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
935         gid = Fj[Fi[i] + j];
936         if (j < nzLeft) { // left, in global
937           Foj[Foi[i] + j] = gid;
938         } else if (j < nzLeft + len) { // diag, in local
939           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
940         } else { // right, in global
941           Foj[Foi[i] + j - len] = gid;
942         }
943       }
944     }
945     PetscCall(PetscFree2(jbuf, Fj));
946     PetscCall(PetscFree(Fi));
947 
948     // Reduce global indices in Foj[] and garray1[] into local ones
949     PetscInt n2, *garray2;
950     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
951 
952     // Record the plans built above, for reuse
953     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
954     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
955     Kokkos::deep_copy(irootloc_h, tmp);
956     mm->sf        = bcastSF;
957     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
958     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
959     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
960     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
961     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
962     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
963     mm->garray    = garray2;
964     mm->n         = n2;
965 
966     // Output Fd and Fo in KokkosCsrMatrix format
967     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
968     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
969     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
970     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
971     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
972 
973     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
974     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
975 
976     // Compute kernel launch parameters in merging E or splitting F
977     PetscInt teamSize, vectorLength, rowsPerTeam;
978 
979     teamSize = vectorLength = rowsPerTeam = -1;
980     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
981     mm->E_TeamSize     = teamSize;
982     mm->E_VectorLength = vectorLength;
983     mm->E_RowsPerTeam  = rowsPerTeam;
984 
985     teamSize = vectorLength = rowsPerTeam = -1;
986     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
987     mm->F_TeamSize     = teamSize;
988     mm->F_VectorLength = vectorLength;
989     mm->F_RowsPerTeam  = rowsPerTeam;
990   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
991 
992   // Sync E's value to device
993   akok->a_dual.sync_device();
994   bkok->a_dual.sync_device();
995 
996   // Handy aliases
997   const auto &Aa = akok->a_dual.view_device();
998   const auto &Ba = bkok->a_dual.view_device();
999   const auto &Ai = akok->i_dual.view_device();
1000   const auto &Bi = bkok->i_dual.view_device();
1001 
1002   // Fetch the plans
1003   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1004   PetscSF             &bcastSF   = mm->sf;
1005   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1006   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1007   PetscIntKokkosView  &irootloc  = mm->irootloc;
1008   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1009 
1010   PetscInt teamSize     = mm->E_TeamSize;
1011   PetscInt vectorLength = mm->E_VectorLength;
1012   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1013   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1014 
1015   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1016   PetscCallCXX(Kokkos::parallel_for(
1017     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1018       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1019         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1020         if (r < irootloc.extent(0)) {
1021           PetscInt i      = irootloc(r); // row i of E
1022           PetscInt disp   = rowoffset(r);
1023           PetscInt alen   = Ai(i + 1) - Ai(i);
1024           PetscInt blen   = Bi(i + 1) - Bi(i);
1025           PetscInt nzleft = E_NzLeft(i);
1026 
1027           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1028             if (j < nzleft) { // B left
1029               rootBuf(disp + j) = Ba(Bi(i) + j);
1030             } else if (j < nzleft + alen) { // diag A
1031               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1032             } else { // B right
1033               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1034             }
1035           });
1036         }
1037       });
1038     }));
1039   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1040   PetscFunctionReturn(PETSC_SUCCESS);
1041 }
1042 
1043 // To finish MatMPIAIJKokkosBcast.
1044 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1045 {
1046   PetscFunctionBegin;
1047   const auto &Fd  = mm->Fd;
1048   const auto &Fo  = mm->Fo;
1049   const auto &Fdi = Fd.graph.row_map;
1050   const auto &Foi = Fo.graph.row_map;
1051   auto       &Fda = Fd.values;
1052   auto       &Foa = Fo.values;
1053   auto        Fm  = Fd.numRows();
1054 
1055   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1056   PetscSF             &bcastSF      = mm->sf;
1057   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1058   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1059   PetscInt             teamSize     = mm->F_TeamSize;
1060   PetscInt             vectorLength = mm->F_VectorLength;
1061   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1062   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1063 
1064   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1065 
1066   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1067   PetscCallCXX(Kokkos::parallel_for(
1068     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1069       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1070         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1071         if (i < Fm) {
1072           PetscInt nzLeft = F_NzLeft(i);
1073           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1074           PetscInt blen   = Foi(i + 1) - Foi(i);
1075           PetscInt Fii    = Fdi(i) + Foi(i);
1076 
1077           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1078             PetscScalar val = leafBuf(Fii + j);
1079             if (j < nzLeft) { // left
1080               Foa(Foi(i) + j) = val;
1081             } else if (j < nzLeft + alen) { // diag
1082               Fda(Fdi(i) + j - nzLeft) = val;
1083             } else { // right
1084               Foa(Foi(i) + j - alen) = val;
1085             }
1086           });
1087         }
1088       });
1089     }));
1090   PetscFunctionReturn(PETSC_SUCCESS);
1091 }
1092 
1093 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1094 {
1095   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1096   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1097   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1098   PetscInt        cstart, cend;
1099   MPI_Comm        comm;
1100 
1101   PetscFunctionBegin;
1102   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1103   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1104   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1105   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1106   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1107   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1108   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1109 
1110   // TODO: add command line options to select spgemm algorithms
1111   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1112 
1113   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1114 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1115   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1116   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1117   #endif
1118 #endif
1119 
1120   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1121   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1122   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1123   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1124 
1125   // Aot * (B's diag + B's off-diag)
1126   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1127   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1128   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1129   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1130   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1131   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1132 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1133 
1134   PetscCallCXX(sort_crs_matrix(mm->C3));
1135   PetscCallCXX(sort_crs_matrix(mm->C4));
1136 #endif
1137 
1138   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1139   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1140   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1141   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1142 
1143   // Adt * (B's diag + B's off-diag)
1144   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1145   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1146   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1147   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1148 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1149   PetscCallCXX(sort_crs_matrix(mm->C1));
1150   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1151 #endif
1152 
1153   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1154 
1155   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1156   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1157   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1158   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1159   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1160 
1161   // C = (C1+Fd, C2+Fo)
1162   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1163   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1164   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1165   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1166   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1167   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1168   PetscFunctionReturn(PETSC_SUCCESS);
1169 }
1170 
1171 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1172 {
1173   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1174   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1175   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1176   MPI_Comm        comm;
1177 
1178   PetscFunctionBegin;
1179   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1180   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1181   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1182   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1183   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1184 
1185   // Aot * (B's diag + B's off-diag)
1186   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1187   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1188 
1189   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1190   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1191 
1192   // Adt * (B's diag + B's off-diag)
1193   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1194   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1195 
1196   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1197 
1198   // C = (C1+Fd, C2+Fo)
1199   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1200   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1201   PetscFunctionReturn(PETSC_SUCCESS);
1202 }
1203 
/* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos

  A = (Ad, Ao) and B = (Bd, Bo) are the diagonal/off-diagonal blocks of the two operands. F = (mm->Fd, mm->Fo) is
  formed by broadcasting rows of B over A's Mvctx. The product is then assembled blockwise as
    Cd = C1 + C3,  C1 = Ad*Bd,  C3 = Ao*Fd
    Co = C2 + C4,  C2 = Ad*Bo (with renumbered columns),  C4 = Ao*Fo
  Because spgemm_numeric() and spadd_numeric() are also run here, mm->Cd and mm->Co hold values (not only the
  sparsity pattern) on return.

  Input Parameters:
+  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product (not referenced in this body).
.  A        - an MPIAIJKOKKOS matrix
.  B        - an MPIAIJKOKKOS matrix
-  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Ad, Ao, Bd, Bo;

  PetscFunctionBegin;
  // Device CSR views of the diagonal (->A) and off-diagonal (->B) blocks of A and B
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One kernel handle per partial product (C1..C4). They live in mm, so the numeric phase can reuse them.
  mm->kh1.create_spgemm_handle(spgemm_alg);
  mm->kh2.create_spgemm_handle(spgemm_alg);
  mm->kh3.create_spgemm_handle(spgemm_alg);
  mm->kh4.create_spgemm_handle(spgemm_alg);

  // Bcast B's rows to form F, and overlap the communication.
  // map_h is allocated uninitialized (NoInit) and its buffer is handed to the Bcast routines, which must fill it
  // before it is copied to the device below to renumber C2_mid's columns.
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
  PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // A's diag * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  // Older KokkosKernels: sort explicitly, because spadd below is set up for sorted inputs
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  // F (mm->Fd, mm->Fo) is only usable after the Bcast completes
  PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // A's off-diag * (F's diag + F's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  // newj(i) = map(oldj(i)): translate every column index of C2_mid through map
  PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  // Values and row_map are taken from C2_mid, so a later numeric update of C2_mid is also seen through C2
  mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);

  // C = (Cd, Co) = (C1+C3, C2+C4)
  mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
  mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1282 
1283 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1284 {
1285   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1286   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1287   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1288 
1289   PetscFunctionBegin;
1290   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1291   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1292   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1293   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1294 
1295   // Bcast B's rows to form F, and overlap the communication
1296   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1297 
1298   // A's diag * (B's diag + B's off-diag)
1299   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1300   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1301 
1302   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1303 
1304   // A's off-diag * (F's diag + F's off-diag)
1305   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1306   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1307 
1308   // C = (Cd, Co) = (C1+C3, C2+C4)
1309   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1310   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1311   PetscFunctionReturn(PETSC_SUCCESS);
1312 }
1313 
1314 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1315 {
1316   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1317   Mat_Product                 *product;
1318   MatProductData_MPIAIJKokkos *pdata;
1319   MatProductType               ptype;
1320   Mat                          A, B;
1321 
1322   PetscFunctionBegin;
1323   MatCheckProduct(C, 1); // make sure C is a product
1324   product = C->product;
1325   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1326   ptype   = product->type;
1327   A       = product->A;
1328   B       = product->B;
1329 
1330   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1331   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1332   // we still do numeric.
1333   if (pdata->reusesym) { // numeric reuses results from symbolic
1334     pdata->reusesym = PETSC_FALSE;
1335     PetscFunctionReturn(PETSC_SUCCESS);
1336   }
1337 
1338   if (ptype == MATPRODUCT_AB) {
1339     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1340   } else if (ptype == MATPRODUCT_AtB) {
1341     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1342   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1343     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1344     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1345   }
1346   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1347   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1348   PetscFunctionReturn(PETSC_SUCCESS);
1349 }
1350 
/*
  MatProductSymbolic_MPIAIJKokkos - symbolic phase for AB, AtB and PtAP products of MPIAIJKOKKOS matrices.

  Sets C's sizes and type, and runs the flavor-specific symbolic routine(s). Those routines also compute values.
  C's diagonal/off-diagonal blocks are built from the resulting device CSR matrices, and the product data is
  attached to C together with its destroy and numeric callbacks.
  NOTE(review): pdata and the MatMatStruct objects are only attached to C at the very end, so an error raised in
  between leaks them (and Z for PtAP).
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *pdata;
  MatMatStruct                *mm = NULL; // points at the struct holding the final Cd/Co of C
  PetscInt                     m, n, M, N;
  Mat                          Cd, Co;
  MPI_Comm                     comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
  MatCheckProduct(C, 1);
  product = C->product;
  PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  // Local/global sizes of C for each supported product type
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  pdata           = new MatProductData_MPIAIJKokkos();
  // The symbolic routines below also fill in values. reusesym lets the first numeric call be skipped
  // (see MatProductNumeric_MPIAIJKokkos).
  pdata->reusesym = product->api_user;

  if (ptype == MATPRODUCT_AB) {
    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
    mm = pdata->mmAB = mmAB;
  } else if (ptype == MATPRODUCT_AtB) {
    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
    mm = pdata->mmAtB = mmAtB;
  } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
    Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z

    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
    pdata->mmAB = mmAB;

    m = A->rmap->n; // Z's layout
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    // Wrap Zd/Zo into a parallel MPIAIJKOKKOS matrix Z so that it can be fed to the AtB routine
    PetscCall(MatCreate(comm, &Z));
    PetscCall(MatSetSizes(Z, m, n, M, N));
    PetscCall(PetscLayoutSetUp(Z->rmap));
    PetscCall(PetscLayoutSetUp(Z->cmap));
    PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
    PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));

    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

    pdata->Z = Z; // give ownership to pdata
    mm = pdata->mmAtB = mmAtB;
  }

  // Turn the final device CSR matrices into C's diagonal/off-diagonal SeqAIJKokkos blocks
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
  PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
  /* set block sizes */
  // Only the PtAP, AB and AtB cases are reachable; the first switch above rejects all other product types
  switch (ptype) {
  case MATPRODUCT_PtAP:
    if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs));
    break;
  case MATPRODUCT_RARt:
    if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs));
    break;
  case MATPRODUCT_ABC:
    PetscCall(MatSetBlockSizesFromMats(C, A, product->C));
    break;
  case MATPRODUCT_AB:
    PetscCall(MatSetBlockSizesFromMats(C, A, B));
    break;
  case MATPRODUCT_AtB:
    if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs));
    break;
  case MATPRODUCT_ABt:
    if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs));
    break;
  default:
    SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]);
  }
  // Hand pdata over to C's product; it is released through MatProductDataDestroy_MPIAIJKokkos
  C->product->data       = pdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1468 
1469 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1470 {
1471   Mat_Product *product = mat->product;
1472   PetscBool    match   = PETSC_FALSE;
1473   PetscBool    usecpu  = PETSC_FALSE;
1474 
1475   PetscFunctionBegin;
1476   MatCheckProduct(mat, 1);
1477   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1478   if (match) { /* we can always fallback to the CPU if requested */
1479     switch (product->type) {
1480     case MATPRODUCT_AB:
1481       if (product->api_user) {
1482         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1483         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1484         PetscOptionsEnd();
1485       } else {
1486         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1487         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1488         PetscOptionsEnd();
1489       }
1490       break;
1491     case MATPRODUCT_AtB:
1492       if (product->api_user) {
1493         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1494         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1495         PetscOptionsEnd();
1496       } else {
1497         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1498         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1499         PetscOptionsEnd();
1500       }
1501       break;
1502     case MATPRODUCT_PtAP:
1503       if (product->api_user) {
1504         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1505         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1506         PetscOptionsEnd();
1507       } else {
1508         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1509         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1510         PetscOptionsEnd();
1511       }
1512       break;
1513     default:
1514       break;
1515     }
1516     match = (PetscBool)!usecpu;
1517   }
1518   if (match) {
1519     switch (product->type) {
1520     case MATPRODUCT_AB:
1521     case MATPRODUCT_AtB:
1522     case MATPRODUCT_PtAP:
1523       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1524       break;
1525     default:
1526       break;
1527     }
1528   }
1529   /* fallback to MPIAIJ ops */
1530   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1531   PetscFunctionReturn(PETSC_SUCCESS);
1532 }
1533 
1534 // Mirror of MatCOOStruct_MPIAIJ on device
1535 struct MatCOOStruct_MPIAIJKokkos {
1536   PetscCount           n;
1537   PetscSF              sf;
1538   PetscCount           Annz, Bnnz;
1539   PetscCount           Annz2, Bnnz2;
1540   PetscCountKokkosView Ajmap1, Aperm1;
1541   PetscCountKokkosView Bjmap1, Bperm1;
1542   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1543   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1544   PetscCountKokkosView Cperm1;
1545   MatScalarKokkosView  sendbuf, recvbuf;
1546 
1547   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1548   {
1549     auto &exec = PetscGetKokkosExecutionSpace();
1550 
1551     n       = coo_h->n;
1552     sf      = coo_h->sf;
1553     Annz    = coo_h->Annz;
1554     Bnnz    = coo_h->Bnnz;
1555     Annz2   = coo_h->Annz2;
1556     Bnnz2   = coo_h->Bnnz2;
1557     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1558     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1559     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1560     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1561     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1562     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1563     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1564     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1565     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1566     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1567     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1568     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1569     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1570     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1571   }
1572 
1573   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1574 };
1575 
1576 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void **data)
1577 {
1578   PetscFunctionBegin;
1579   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(*data));
1580   PetscFunctionReturn(PETSC_SUCCESS);
1581 }
1582 
1583 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1584 {
1585   PetscContainer             container_h, container_d;
1586   MatCOOStruct_MPIAIJ       *coo_h;
1587   MatCOOStruct_MPIAIJKokkos *coo_d;
1588 
1589   PetscFunctionBegin;
1590   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1591   mat->preallocated = PETSC_TRUE;
1592   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1593   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1594   PetscCall(MatZeroEntries(mat));
1595 
1596   // Copy the COO struct to device
1597   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1598   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1599   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1600 
1601   // Put the COO struct in a container and then attach that to the matrix
1602   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1603   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1604   PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1605   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1606   PetscCall(PetscContainerDestroy(&container_d));
1607   PetscFunctionReturn(PETSC_SUCCESS);
1608 }
1609 
1610 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1611 {
1612   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1613   Mat                            A = mpiaij->A, B = mpiaij->B;
1614   MatScalarKokkosView            Aa, Ba;
1615   MatScalarKokkosView            v1;
1616   PetscMemType                   memtype;
1617   PetscContainer                 container;
1618   MatCOOStruct_MPIAIJKokkos     *coo;
1619   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
1620 
1621   PetscFunctionBegin;
1622   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1623   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1624 
1625   const auto &n      = coo->n;
1626   const auto &Annz   = coo->Annz;
1627   const auto &Annz2  = coo->Annz2;
1628   const auto &Bnnz   = coo->Bnnz;
1629   const auto &Bnnz2  = coo->Bnnz2;
1630   const auto &vsend  = coo->sendbuf;
1631   const auto &v2     = coo->recvbuf;
1632   const auto &Ajmap1 = coo->Ajmap1;
1633   const auto &Ajmap2 = coo->Ajmap2;
1634   const auto &Aimap2 = coo->Aimap2;
1635   const auto &Bjmap1 = coo->Bjmap1;
1636   const auto &Bjmap2 = coo->Bjmap2;
1637   const auto &Bimap2 = coo->Bimap2;
1638   const auto &Aperm1 = coo->Aperm1;
1639   const auto &Aperm2 = coo->Aperm2;
1640   const auto &Bperm1 = coo->Bperm1;
1641   const auto &Bperm2 = coo->Bperm2;
1642   const auto &Cperm1 = coo->Cperm1;
1643 
1644   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1645   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1646     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1647   } else {
1648     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1649   }
1650 
1651   if (imode == INSERT_VALUES) {
1652     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1653     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1654   } else {
1655     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1656     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1657   }
1658 
1659   PetscCall(PetscLogGpuTimeBegin());
1660   /* Pack entries to be sent to remote */
1661   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1662 
1663   /* Send remote entries to their owner and overlap the communication with local computation */
1664   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1665   /* Add local entries to A and B in one kernel */
1666   Kokkos::parallel_for(
1667     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1668       PetscScalar sum = 0.0;
1669       if (i < Annz) {
1670         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1671         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1672       } else {
1673         i -= Annz;
1674         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1675         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1676       }
1677     });
1678   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1679 
1680   /* Add received remote entries to A and B in one kernel */
1681   Kokkos::parallel_for(
1682     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1683       if (i < Annz2) {
1684         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1685       } else {
1686         i -= Annz2;
1687         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1688       }
1689     });
1690   PetscCall(PetscLogGpuTimeEnd());
1691 
1692   if (imode == INSERT_VALUES) {
1693     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1694     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1695   } else {
1696     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1697     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1698   }
1699   PetscFunctionReturn(PETSC_SUCCESS);
1700 }
1701 
1702 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1703 {
1704   PetscFunctionBegin;
1705   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1706   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1707   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1708   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1709 #if defined(PETSC_HAVE_HYPRE)
1710   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
1711 #endif
1712   PetscCall(MatDestroy_MPIAIJ(A));
1713   PetscFunctionReturn(PETSC_SUCCESS);
1714 }
1715 
1716 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1717 {
1718   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1719   PetscBool   congruent;
1720 
1721   PetscFunctionBegin;
1722   PetscCall(MatHasCongruentLayouts(A, &congruent));
1723   if (congruent) { // square matrix and the diagonals are solely in the diag block
1724     PetscCall(MatShift(mpiaij->A, a));
1725   } else { // too hard, use the general version
1726     PetscCall(MatShift_Basic(A, a));
1727   }
1728   PetscFunctionReturn(PETSC_SUCCESS);
1729 }
1730 
1731 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1732 {
1733   PetscFunctionBegin;
1734   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1735   B->ops->mult                  = MatMult_MPIAIJKokkos;
1736   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1737   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1738   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1739   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1740   B->ops->shift                 = MatShift_MPIAIJKokkos;
1741 
1742   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1743   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1744   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1745   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1746 #if defined(PETSC_HAVE_HYPRE)
1747   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
1748 #endif
1749   PetscFunctionReturn(PETSC_SUCCESS);
1750 }
1751 
1752 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1753 {
1754   Mat         B;
1755   Mat_MPIAIJ *a;
1756 
1757   PetscFunctionBegin;
1758   if (reuse == MAT_INITIAL_MATRIX) {
1759     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1760   } else if (reuse == MAT_REUSE_MATRIX) {
1761     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1762   }
1763   B = *newmat;
1764 
1765   B->boundtocpu = PETSC_FALSE;
1766   PetscCall(PetscFree(B->defaultvectype));
1767   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1768   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1769 
1770   a = static_cast<Mat_MPIAIJ *>(A->data);
1771   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1772   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1773   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1774   PetscCall(MatSetOps_MPIAIJKokkos(B));
1775   PetscFunctionReturn(PETSC_SUCCESS);
1776 }
1777 
1778 /*MC
   MATAIJKOKKOS - "aijkokkos" - a matrix type to be used for CSR sparse matrices with Kokkos
1780 
1781    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1782 
1783    Options Database Key:
1784 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1785 
1786   Level: beginner
1787 
1788 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1789 M*/
1790 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1791 {
1792   PetscFunctionBegin;
1793   PetscCall(PetscKokkosInitializeCheck());
1794   PetscCall(MatCreate_MPIAIJ(A));
1795   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1796   PetscFunctionReturn(PETSC_SUCCESS);
1797 }
1798 
1799 /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1803 
1804   Collective
1805 
1806   Input Parameters:
+ comm  - MPI communicator
1808 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1809            This value should be the same as the local size used in creating the
1810            y vector for the matrix-vector product y = Ax.
1811 . n     - This value should be the same as the local size used in creating the
1812        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1813        calculated if N is given) For square matrices n is almost always `m`.
1814 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1815 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1816 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1817            (same value is used for all local rows)
1818 . d_nnz - array containing the number of nonzeros in the various rows of the
1819            DIAGONAL portion of the local submatrix (possibly different for each row)
1820            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1821            The size of this array is equal to the number of local rows, i.e `m`.
1822            For matrices you plan to factor you must leave room for the diagonal entry and
1823            put in the entry even if it is zero.
1824 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1825            submatrix (same value is used for all local rows).
1826 - o_nnz - array containing the number of nonzeros in the various rows of the
1827            OFF-DIAGONAL portion of the local submatrix (possibly different for
1828            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1829            structure. The size of this array is equal to the number
1830            of local rows, i.e `m`.
1831 
1832   Output Parameter:
1833 . A - the matrix
1834 
1835   Level: intermediate
1836 
1837   Notes:
1838   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1839   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1840   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1841 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1843   storage.  That is, the stored row and column indices can begin at
1844   either one (as in Fortran) or zero.
1845 
1846 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1847           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1848 @*/
1849 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1850 {
1851   PetscMPIInt size;
1852 
1853   PetscFunctionBegin;
1854   PetscCall(MatCreate(comm, A));
1855   PetscCall(MatSetSizes(*A, m, n, M, N));
1856   PetscCallMPI(MPI_Comm_size(comm, &size));
1857   if (size > 1) {
1858     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1859     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1860   } else {
1861     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1862     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1863   }
1864   PetscFunctionReturn(PETSC_SUCCESS);
1865 }
1866