xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 6a210b70a61472cf8733db59e8a3fa68f6578c01)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <petsc/private/kokkosimpl.hpp>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <KokkosSparse_spadd.hpp>
9 #include <KokkosSparse_spgemm.hpp>
10 
11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12 {
13   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19    */
20   if (mode == MAT_FINAL_ASSEMBLY) {
21     PetscScalarKokkosView v;
22 
23     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
24     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
26     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
28   }
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33 {
34   Mat_MPIAIJ *mpiaij;
35 
36   PetscFunctionBegin;
37   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
38   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
39   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
40   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
41   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
46 {
47   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
48   PetscInt    nt;
49 
50   PetscFunctionBegin;
51   PetscCall(VecGetLocalSize(xx, &nt));
52   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
53   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
54   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
55   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
56   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
57   PetscFunctionReturn(PETSC_SUCCESS);
58 }
59 
60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
61 {
62   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
63   PetscInt    nt;
64 
65   PetscFunctionBegin;
66   PetscCall(VecGetLocalSize(xx, &nt));
67   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
68   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
69   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
70   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
71   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
76 {
77   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
78   PetscInt    nt;
79 
80   PetscFunctionBegin;
81   PetscCall(VecGetLocalSize(xx, &nt));
82   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
83   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
84   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
85   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
86   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
91    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
92    C still uses local column ids. Their corresponding global column ids are returned in glob.
93 */
94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
95 {
96   Mat             Ad, Ao;
97   const PetscInt *cmap;
98 
99   PetscFunctionBegin;
100   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
101   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
102   if (glob) {
103     PetscInt cst, i, dn, on, *gidx;
104     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
105     PetscCall(MatGetLocalSize(Ao, NULL, &on));
106     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
107     PetscCall(PetscMalloc1(dn + on, &gidx));
108     for (i = 0; i < dn; i++) gidx[i] = cst + i;
109     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
110     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
111   }
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
115 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
116 struct MatMatStruct {
117   PetscInt            n, *garray;     // C's garray and its size.
118   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
119   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
120   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
121   PetscIntKokkosView  E_NzLeft;
122   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
123   MatScalarKokkosView rootBuf, leafBuf;
124   KokkosCsrMatrix     Fd, Fo; // F in split form
125 
126   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
127   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
128   KernelHandle kh3; // compute C3
129   KernelHandle kh4; // compute C4
130 
131   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
132   PetscInt E_VectorLength;
133   PetscInt E_RowsPerTeam;
134   PetscInt F_TeamSize;
135   PetscInt F_VectorLength;
136   PetscInt F_RowsPerTeam;
137 
138   ~MatMatStruct()
139   {
140     PetscFunctionBegin;
141     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
142     PetscFunctionReturnVoid();
143   }
144 };
145 
146 struct MatMatStruct_AB : public MatMatStruct {
147   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
148   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
149   PetscIntKokkosView rowoffset;
150 };
151 
152 struct MatMatStruct_AtB : public MatMatStruct {
153   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
154   MatColIdxKokkosView Fdjperm;
155   MatColIdxKokkosView Fojmap;
156   MatColIdxKokkosView Fojperm;
157 };
158 
159 struct MatProductData_MPIAIJKokkos {
160   MatMatStruct_AB  *mmAB     = nullptr;
161   MatMatStruct_AtB *mmAtB    = nullptr;
162   PetscBool         reusesym = PETSC_FALSE;
163   Mat               Z        = nullptr; // store Z=AB in computing BtAB
164 
165   ~MatProductData_MPIAIJKokkos()
166   {
167     delete mmAB;
168     delete mmAtB;
169     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
170   }
171 };
172 
173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
174 {
175   PetscFunctionBegin;
176   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
177   PetscFunctionReturn(PETSC_SUCCESS);
178 }
179 
180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
181    It is similar to MatCreateMPIAIJWithSplitArrays.
182 
183   Input Parameters:
184 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
185 .  A     - the diag matrix using local col ids
186 -  B     - the offdiag matrix using global col ids
187 
188   Output Parameter:
189 .  mat   - the updated MATMPIAIJKOKKOS matrix
190 */
191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
192 {
193   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
194   PetscInt    m, n, M, N, Am, An, Bm, Bn;
195 
196   PetscFunctionBegin;
197   PetscCall(MatGetSize(mat, &M, &N));
198   PetscCall(MatGetLocalSize(mat, &m, &n));
199   PetscCall(MatGetLocalSize(A, &Am, &An));
200   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
201 
202   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
203   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
204   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
205   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
206   mpiaij->A      = A;
207   mpiaij->B      = B;
208   mpiaij->garray = garray;
209 
210   mat->preallocated     = PETSC_TRUE;
211   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
212 
213   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
214   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
215   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
216     also gets mpiaij->B compacted, with its col ids and size reduced
217   */
218   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
219   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
220   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
221   PetscFunctionReturn(PETSC_SUCCESS);
222 }
223 
224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
226 template <class ExecutionSpace>
227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
228 {
229   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
230 
231   PetscFunctionBegin;
232   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
233 
234   if (nnz_per_row < 1) nnz_per_row = 1;
235 
236   int max_vector_length = teamPolicy.vector_length_max();
237 
238   if (vector_length < 1) {
239     vector_length = 1;
240     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
241   }
242 
243   // Determine rows per thread
244   if (rows_per_thread < 1) {
245     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
246     else {
247       if (nnz_per_row < 20 && nnz > 5000000) {
248         rows_per_thread = 256;
249       } else rows_per_thread = 64;
250     }
251   }
252 
253   if (team_size < 1) {
254     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
255       team_size = 256 / vector_length;
256     } else {
257       team_size = 1;
258     }
259   }
260 
261   rows_per_team = rows_per_thread * team_size;
262 
263   if (rows_per_team < 0) {
264     PetscInt nnz_per_team = 4096;
265     PetscInt conc         = ExecutionSpace().concurrency();
266     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
267     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
268   }
269   PetscFunctionReturn(PETSC_SUCCESS);
270 }
271 
272 /*
273   Reduce two sets of global indices into local ones
274 
275   Input Parameters:
276 +  n1          - size of garray1[], the first set
277 .  garray1[n1] - a sorted global index array (without duplicates)
278 .  m           - size of indices[], the second set
279 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
280 
281   Output Parameters:
282 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
283 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
284 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
285 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
286 
287    Example, say
288     n1         = 5
289     garray1[5] = {1, 4, 7, 8, 10}
290     m          = 4
291     indices[4] = {2, 4, 8, 9}
292 
293    Combining them together, we have 7 global indices in garray2[]
294     n2         = 7
295     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
296 
297    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
298     map[5] = {0, 2, 3, 4, 6}
299 
300    On output, indices[] is updated with local indices
301     indices[4] = {1, 2, 4, 5}
302 */
303 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
304 {
305   PetscHMapI    g2l = nullptr;
306   PetscHashIter iter;
307   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
308   PetscInt      n2, *garray2;
309 
310   PetscFunctionBegin;
311   tot = 0;
312   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
313   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
314     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
315     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
316   }
317 
318   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
319     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
320     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
321   }
322 
323   // Pull out (unique) globals in the hash table and put them in garray2[]
324   n2 = tot;
325   PetscCall(PetscMalloc1(n2, &garray2));
326   tot = 0;
327   PetscHashIterBegin(g2l, iter);
328   while (!PetscHashIterAtEnd(g2l, iter)) {
329     PetscHashIterGetKey(g2l, iter, key);
330     PetscHashIterNext(g2l, iter);
331     garray2[tot++] = key;
332   }
333 
334   // Sort garray2[] and then map them to local indices starting from 0
335   PetscCall(PetscSortInt(n2, garray2));
336   PetscCall(PetscHMapIClear(g2l));
337   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
338 
339   // Rewrite indices[] with local indices
340   for (PetscInt i = 0; i < m; i++) {
341     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
342     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
343     indices[i] = val;
344   }
345   // Record the map that maps garray1[i] to garray2[map[i]]
346   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
347   PetscCall(PetscHMapIDestroy(&g2l));
348   *n2_      = n2;
349   *garray2_ = garray2;
350   PetscFunctionReturn(PETSC_SUCCESS);
351 }
352 
353 /*
354   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
355 
356   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
357 
358   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
359   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
360 
361   Input Parameters:
362 +  comm       - MPI communicator of E
363 .  A          - diag block of E, using local column indices
364 .  B          - off-diag block of E, using local column indices
365 .  cstart      - (global) start column of Ed
366 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
367 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
368 .  ownerSF     - the SF specifies ownership (root) of rows in E
369 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
370 -  mm          - to stash intermediate data structures for reuse
371 
372   Output Parameters:
373 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
374 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
375 
376   Notes:
377   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
378 
379  */
380 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
381 {
382   PetscFunctionBegin;
383   if (reuse == MAT_INITIAL_MATRIX) {
384     PetscInt Em = A.numRows(), Fm;
385     PetscInt n1 = B.numCols();
386 
387     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
388 
389     // Do the analysis on host
390     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
391     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
392     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
393     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
394     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
395     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
396 
397     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
398     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
399     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
400     for (PetscInt i = 0; i < Em; i++) {
401       const PetscInt *first, *last, *it;
402       PetscInt        count, step;
403       // std::lower_bound(first,last,cstart), but need to use global column indices
404       first = Bj + Bi[i];
405       last  = Bj + Bi[i + 1];
406       count = last - first;
407       while (count > 0) {
408         it   = first;
409         step = count / 2;
410         it += step;
411         if (garray1[*it] < cstart) { // map local to global
412           first = ++it;
413           count -= step + 1;
414         } else count = step;
415       }
416       E_NzLeft[i] = first - (Bj + Bi[i]);
417       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
418     }
419 
420     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
421     const PetscMPIInt *iranks, *ranks;
422     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
423     PetscInt           niranks, nranks;
424     MPI_Request       *reqs;
425     PetscMPIInt        tag;
426     PetscSF            reduceSF;
427     PetscInt          *sdisp, *rdisp;
428 
429     PetscCall(PetscCommGetNewTag(comm, &tag));
430     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
431     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
432 
433     // Find out length of each row I will receive. Even for the same row index, when they are from
434     // different senders, they might have different lengths (and sparsity patterns)
435     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
436     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
437 
438     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
439 
440     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
441     recvRowLen[0] = 0; // since we will make it in CSR format later
442     recvRowLen++;      // advance the pointer now
443     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
444     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
445     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
446 
447     // Build the real PetscSF for reducing E rows (buffer to buffer)
448     rdisp[0] = 0;
449     for (PetscInt i = 0; i < niranks; i++) {
450       rdisp[i + 1] = rdisp[i];
451       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
452     }
453     recvRowLen--; // put it back into csr format
454     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
455 
456     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
457     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
458     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
459 
460     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
461     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
462     PetscSFNode *iremote;
463 
464     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
465     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
466     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
467 
468     for (PetscInt i = 0; i < nranks; i++) {
469       PetscInt count = 0;
470       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
471       for (PetscInt j = 0; j < count; j++) {
472         iremote[nleaves + j].rank  = ranks[i];
473         iremote[nleaves + j].index = sdisp[i] + j;
474       }
475       nleaves += count;
476     }
477     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
478 
479     PetscCall(PetscSFCreate(comm, &reduceSF));
480     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
481 
482     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
483     PetscInt *sendCol, *recvCol;
484     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
485     for (PetscInt k = 0; k < roffset[nranks]; k++) {
486       PetscInt  i      = rmine[k]; // row to be copied
487       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
488       PetscInt  nzLeft = E_NzLeft[i];
489       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
490       for (PetscInt j = 0; j < alen + blen; j++) {
491         if (j < nzLeft) {
492           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
493         } else if (j < nzLeft + alen) {
494           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
495         } else {
496           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
497         }
498       }
499     }
500     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
501     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
502 
503     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
504     PetscInt *recvRowPerm, *recvColSorted;
505     PetscInt *recvNzPerm, *recvNzPermSorted;
506     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
507 
508     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
509     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
510     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
511 
512     // i[] array, nz are always easiest to compute
513     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
514     MatRowMapType          *Fdi, *Foi;
515     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
516     PetscInt                iter;
517 
518     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
519     Kokkos::deep_copy(Foi_h, 0);
520     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
521     Foi  = Foi_h.data() + 1;
522     iter = 0;
523     while (iter < recvRowCnt) { // iter over received rows
524       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
525       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
526 
527       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
528 
529       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
530       PetscInt  nz    = 0; // nz (with dups) in the current row
531       PetscInt *jbuf  = recvColSorted + FnzDups;
532       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
533       PetscInt *jbuf2 = jbuf; // temp pointers
534       PetscInt *pbuf2 = pbuf;
535       for (PetscInt d = 0; d < dupRows; d++) {
536         PetscInt i   = recvRowPerm[iter + d];
537         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
538         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
539         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
540         jbuf2 += len;
541         pbuf2 += len;
542         nz += len;
543       }
544       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
545 
546       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
547       PetscInt cur = 0;
548       while (cur < nz) {
549         PetscInt curColIdx = jbuf[cur];
550         PetscInt dups      = 1;
551 
552         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
553         if (curColIdx >= cstart && curColIdx < cend) {
554           Fdi[curRowIdx]++;
555           FdnzDups += dups;
556         } else {
557           Foi[curRowIdx]++;
558           FonzDups += dups;
559         }
560         cur += dups;
561       }
562 
563       FnzDups += nz;
564       iter += dupRows; // Move to next unique row
565     }
566 
567     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
568     Foi = Foi_h.data();
569     for (PetscInt i = 0; i < Fm; i++) {
570       Fdi[i + 1] += Fdi[i];
571       Foi[i + 1] += Foi[i];
572     }
573     Fdnz = Fdi[Fm];
574     Fonz = Foi[Fm];
575     PetscCall(PetscFree2(sendCol, recvCol));
576 
577     // Allocate j, jmap, jperm for Fd and Fo
578     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
579     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
580     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
581     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
582     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
583     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
584 
585     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
586     Fdjmap[0] = 0;
587     Fojmap[0] = 0;
588     FnzDups   = 0;
589     Fdnz      = 0;
590     Fonz      = 0;
591     iter      = 0; // iter over received rows
592     while (iter < recvRowCnt) {
593       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
594       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
595       PetscInt nz        = 0;                           // nz (with dups) in the current row
596 
597       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
598       for (PetscInt d = 0; d < dupRows; d++) {
599         PetscInt i = recvRowPerm[iter + d];
600         nz += recvRowLen[i + 1] - recvRowLen[i];
601       }
602 
603       PetscInt *jbuf = recvColSorted + FnzDups;
604       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
605       PetscInt cur = 0;
606       while (cur < nz) {
607         PetscInt curColIdx = jbuf[cur];
608         PetscInt dups      = 1;
609 
610         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
611         if (curColIdx >= cstart && curColIdx < cend) {
612           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
613           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
614           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
615           FdnzDups += dups;
616           Fdnz++;
617         } else {
618           Foj[Fonz]        = curColIdx; // in global
619           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
620           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
621           FonzDups += dups;
622           Fonz++;
623         }
624         cur += dups;
625         FnzDups += dups;
626       }
627       iter += dupRows; // Move to next unique row
628     }
629     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
630     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
631 
632     // Combine global column indices in garray1[] and Foj[]
633     PetscInt n2, *garray2;
634 
635     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
636     mm->sf       = reduceSF;
637     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
638     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
639     mm->garray   = garray2; // give ownership, so no free
640     mm->n        = n2;
641     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
642     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
643     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
644     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
645     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
646 
647     // Output Fd and Fo in KokkosCsrMatrix format
648     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
649     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
650     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
651     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
652     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
653     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
654 
655     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
656     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
657 
658     // Compute kernel launch parameters in merging E
659     PetscInt teamSize, vectorLength, rowsPerTeam;
660 
661     teamSize = vectorLength = rowsPerTeam = -1;
662     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
663     mm->E_TeamSize     = teamSize;
664     mm->E_VectorLength = vectorLength;
665     mm->E_RowsPerTeam  = rowsPerTeam;
666   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
667 
668   // Handy aliases
669   auto       &Aa           = A.values;
670   auto       &Ba           = B.values;
671   const auto &Ai           = A.graph.row_map;
672   const auto &Bi           = B.graph.row_map;
673   const auto &E_NzLeft     = mm->E_NzLeft;
674   auto       &leafBuf      = mm->leafBuf;
675   auto       &rootBuf      = mm->rootBuf;
676   PetscSF     reduceSF     = mm->sf;
677   PetscInt    Em           = A.numRows();
678   PetscInt    teamSize     = mm->E_TeamSize;
679   PetscInt    vectorLength = mm->E_VectorLength;
680   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
681   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
682 
683   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
684   PetscCallCXX(Kokkos::parallel_for(
685     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
686       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
687         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
688         if (i < Em) {
689           PetscInt disp   = Ai(i) + Bi(i);
690           PetscInt alen   = Ai(i + 1) - Ai(i);
691           PetscInt blen   = Bi(i + 1) - Bi(i);
692           PetscInt nzleft = E_NzLeft(i);
693 
694           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
695             MatScalar &val = leafBuf(disp + j);
696             if (j < nzleft) { // B left
697               val = Ba(Bi(i) + j);
698             } else if (j < nzleft + alen) { // diag A
699               val = Aa(Ai(i) + j - nzleft);
700             } else { // B right
701               val = Ba(Bi(i) + j - alen);
702             }
703           });
704         }
705       });
706     }));
707   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
708   PetscFunctionReturn(PETSC_SUCCESS);
709 }
710 
711 // To finish MatMPIAIJKokkosReduce.
712 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
713 {
714   auto       &leafBuf  = mm->leafBuf;
715   auto       &rootBuf  = mm->rootBuf;
716   auto       &Fda      = mm->Fd.values;
717   const auto &Fdjmap   = mm->Fdjmap;
718   const auto &Fdjperm  = mm->Fdjperm;
719   auto        Fdnz     = mm->Fd.nnz();
720   auto       &Foa      = mm->Fo.values;
721   const auto &Fojmap   = mm->Fojmap;
722   const auto &Fojperm  = mm->Fojperm;
723   auto        Fonz     = mm->Fo.nnz();
724   PetscSF     reduceSF = mm->sf;
725 
726   PetscFunctionBegin;
727   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
728 
729   // Reduce data in rootBuf to Fd and Fo
730   PetscCallCXX(Kokkos::parallel_for(
731     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
732       PetscScalar sum = 0.0;
733       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
734       Fda(i) = sum;
735     }));
736 
737   PetscCallCXX(Kokkos::parallel_for(
738     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
739       PetscScalar sum = 0.0;
740       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
741       Foa(i) = sum;
742     }));
743   PetscFunctionReturn(PETSC_SUCCESS);
744 }
745 
746 /*
747   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
748 
749   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
750   device and involves various index mapping.
751 
752   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
753   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
754   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
755   F has the same column layout as E.
756 
757   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
758   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
759   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
760   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
761   column indices in Fo and update Fo with local indices.
762 
763    Input Parameters:
764 +   E       - the MPIAIJKOKKOS matrix
765 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
766 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
767 -   mm      - to stash matproduct intermediate data structures
768 
769     Output Parameters:
770 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
771 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
772 
773     Notes:
774     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
775     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
776 */
777 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
778 {
779   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
780   Mat               A = empi->A, B = empi->B; // diag and off-diag
781   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
782   PetscInt          Em = E->rmap->n; // #local rows
783   MPI_Comm          comm;
784 
785   PetscFunctionBegin;
786   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
787   if (reuse == MAT_INITIAL_MATRIX) {
788     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
789     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
790     const PetscInt *garray1 = empi->garray; // its size is n1
791     PetscInt        cstart, cend;
792     PetscSF         bcastSF;
793 
794     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
795 
796     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
797     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
798     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
799     for (PetscInt i = 0; i < Em; i++) {
800       const PetscInt *first, *last, *it;
801       PetscInt        count, step;
802       // std::lower_bound(first,last,cstart), but need to use global column indices
803       first = Bj + Bi[i];
804       last  = Bj + Bi[i + 1];
805       count = last - first;
806       while (count > 0) {
807         it   = first;
808         step = count / 2;
809         it += step;
810         if (empi->garray[*it] < cstart) { // map local to global
811           first = ++it;
812           count -= step + 1;
813         } else count = step;
814       }
815       E_NzLeft[i] = first - (Bj + Bi[i]);
816       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
817     }
818 
819     // Compute row pointer Fi of F
820     PetscInt *Fi, Fm, Fnz;
821     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
822     PetscCall(PetscMalloc1(Fm + 1, &Fi));
823     Fi[0] = 0;
824     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
825     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
826     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
827     Fnz = Fi[Fm];
828 
829     // Build the real PetscSF for bcasting E rows (buffer to buffer)
830     const PetscMPIInt *iranks, *ranks;
831     const PetscInt    *ioffset, *irootloc, *roffset;
832     PetscInt           niranks, nranks, *sdisp, *rdisp;
833     MPI_Request       *reqs;
834     PetscMPIInt        tag;
835 
836     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
837     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
838     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
839 
840     sdisp[0] = 0; // send displacement
841     for (PetscInt i = 0; i < niranks; i++) {
842       sdisp[i + 1] = sdisp[i];
843       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
844         PetscInt r = irootloc[j]; // row to be sent
845         sdisp[i + 1] += E_RowLen[r];
846       }
847     }
848 
849     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
850     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
851     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
852     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
853 
854     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
855     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
856     PetscSFNode *iremote;                  // give ownership to bcastSF
857     PetscCall(PetscMalloc1(nleaves, &iremote));
858     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
859       PetscInt k = 0;
860       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
861         iremote[j].rank  = ranks[i];
862         iremote[j].index = rdisp[i] + k; // their root location
863         k++;
864       }
865     }
866     PetscCall(PetscSFCreate(comm, &bcastSF));
867     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
868     PetscCall(PetscFree3(sdisp, rdisp, reqs));
869 
870     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
871     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
872     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
873     rowoffset[0]                     = 0;
874     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
875 
876     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
877     PetscInt *jbuf, *Fj;
878     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
879     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
880       PetscInt  i      = irootloc[k]; // row to be copied
881       PetscInt *buf    = &jbuf[rowoffset[k]];
882       PetscInt  nzLeft = E_NzLeft[i];
883       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
884       for (PetscInt j = 0; j < alen + blen; j++) {
885         if (j < nzLeft) {
886           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
887         } else if (j < nzLeft + alen) {
888           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
889         } else {
890           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
891         }
892       }
893     }
894     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
895     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
896 
897     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
898     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
899     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
900     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
901     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
902 
903     Fdi[0] = Foi[0] = 0;
904     for (PetscInt i = 0; i < Fm; i++) {
905       PetscInt *first, *last, *lb1, *lb2;
906       // cut the row into: Left, [cstart, cend), Right
907       first       = Fj + Fi[i];
908       last        = Fj + Fi[i + 1];
909       lb1         = std::lower_bound(first, last, cstart);
910       F_NzLeft[i] = lb1 - first;
911       lb2         = std::lower_bound(first, last, cend);
912       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
913       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
914     }
915     for (PetscInt i = 0; i < Fm; i++) {
916       Fdi[i + 1] += Fdi[i];
917       Foi[i + 1] += Foi[i];
918     }
919 
920     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
921     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
922     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
923     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
924 
925     for (PetscInt i = 0; i < Fm; i++) {
926       PetscInt nzLeft = F_NzLeft[i];
927       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
928       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
929         gid = Fj[Fi[i] + j];
930         if (j < nzLeft) { // left, in global
931           Foj[Foi[i] + j] = gid;
932         } else if (j < nzLeft + len) { // diag, in local
933           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
934         } else { // right, in global
935           Foj[Foi[i] + j - len] = gid;
936         }
937       }
938     }
939     PetscCall(PetscFree2(jbuf, Fj));
940     PetscCall(PetscFree(Fi));
941 
942     // Reduce global indices in Foj[] and garray1[] into local ones
943     PetscInt n2, *garray2;
944     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
945 
946     // Record the plans built above, for reuse
947     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
948     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
949     Kokkos::deep_copy(irootloc_h, tmp);
950     mm->sf        = bcastSF;
951     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
952     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
953     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
954     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
955     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
956     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
957     mm->garray    = garray2;
958     mm->n         = n2;
959 
960     // Output Fd and Fo in KokkosCsrMatrix format
961     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
962     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
963     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
964     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
965     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
966 
967     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
968     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
969 
970     // Compute kernel launch parameters in merging E or splitting F
971     PetscInt teamSize, vectorLength, rowsPerTeam;
972 
973     teamSize = vectorLength = rowsPerTeam = -1;
974     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
975     mm->E_TeamSize     = teamSize;
976     mm->E_VectorLength = vectorLength;
977     mm->E_RowsPerTeam  = rowsPerTeam;
978 
979     teamSize = vectorLength = rowsPerTeam = -1;
980     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
981     mm->F_TeamSize     = teamSize;
982     mm->F_VectorLength = vectorLength;
983     mm->F_RowsPerTeam  = rowsPerTeam;
984   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
985 
986   // Sync E's value to device
987   akok->a_dual.sync_device();
988   bkok->a_dual.sync_device();
989 
990   // Handy aliases
991   const auto &Aa = akok->a_dual.view_device();
992   const auto &Ba = bkok->a_dual.view_device();
993   const auto &Ai = akok->i_dual.view_device();
994   const auto &Bi = bkok->i_dual.view_device();
995 
996   // Fetch the plans
997   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
998   PetscSF             &bcastSF   = mm->sf;
999   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1000   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1001   PetscIntKokkosView  &irootloc  = mm->irootloc;
1002   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1003 
1004   PetscInt teamSize     = mm->E_TeamSize;
1005   PetscInt vectorLength = mm->E_VectorLength;
1006   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1007   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1008 
1009   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1010   PetscCallCXX(Kokkos::parallel_for(
1011     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1012       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1013         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1014         if (r < irootloc.extent(0)) {
1015           PetscInt i      = irootloc(r); // row i of E
1016           PetscInt disp   = rowoffset(r);
1017           PetscInt alen   = Ai(i + 1) - Ai(i);
1018           PetscInt blen   = Bi(i + 1) - Bi(i);
1019           PetscInt nzleft = E_NzLeft(i);
1020 
1021           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1022             if (j < nzleft) { // B left
1023               rootBuf(disp + j) = Ba(Bi(i) + j);
1024             } else if (j < nzleft + alen) { // diag A
1025               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1026             } else { // B right
1027               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1028             }
1029           });
1030         }
1031       });
1032     }));
1033   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1034   PetscFunctionReturn(PETSC_SUCCESS);
1035 }
1036 
1037 // To finish MatMPIAIJKokkosBcast.
1038 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1039 {
1040   PetscFunctionBegin;
1041   const auto &Fd  = mm->Fd;
1042   const auto &Fo  = mm->Fo;
1043   const auto &Fdi = Fd.graph.row_map;
1044   const auto &Foi = Fo.graph.row_map;
1045   auto       &Fda = Fd.values;
1046   auto       &Foa = Fo.values;
1047   auto        Fm  = Fd.numRows();
1048 
1049   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1050   PetscSF             &bcastSF      = mm->sf;
1051   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1052   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1053   PetscInt             teamSize     = mm->F_TeamSize;
1054   PetscInt             vectorLength = mm->F_VectorLength;
1055   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1056   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1057 
1058   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1059 
1060   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1061   PetscCallCXX(Kokkos::parallel_for(
1062     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1063       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1064         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1065         if (i < Fm) {
1066           PetscInt nzLeft = F_NzLeft(i);
1067           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1068           PetscInt blen   = Foi(i + 1) - Foi(i);
1069           PetscInt Fii    = Fdi(i) + Foi(i);
1070 
1071           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1072             PetscScalar val = leafBuf(Fii + j);
1073             if (j < nzLeft) { // left
1074               Foa(Foi(i) + j) = val;
1075             } else if (j < nzLeft + alen) { // diag
1076               Fda(Fdi(i) + j - nzLeft) = val;
1077             } else { // right
1078               Foa(Foi(i) + j - alen) = val;
1079             }
1080           });
1081         }
1082       });
1083     }));
1084   PetscFunctionReturn(PETSC_SUCCESS);
1085 }
1086 
1087 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1088 {
1089   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1090   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1091   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1092   PetscInt        cstart, cend;
1093   MPI_Comm        comm;
1094 
1095   PetscFunctionBegin;
1096   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1097   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1098   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1099   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1100   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1101   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1102   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1103 
1104   // TODO: add command line options to select spgemm algorithms
1105   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1106 
1107   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1108 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1109   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1110   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1111   #endif
1112 #endif
1113 
1114   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1115   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1116   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1117   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1118 
1119   // Aot * (B's diag + B's off-diag)
1120   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1121   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1122   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1123   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1124   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1125   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1126 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1127 
1128   PetscCallCXX(sort_crs_matrix(mm->C3));
1129   PetscCallCXX(sort_crs_matrix(mm->C4));
1130 #endif
1131 
1132   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1133   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1134   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1135   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1136 
1137   // Adt * (B's diag + B's off-diag)
1138   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1139   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1140   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1141   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1142 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1143   PetscCallCXX(sort_crs_matrix(mm->C1));
1144   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1145 #endif
1146 
1147   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1148 
1149   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1150   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1151   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1152   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1153   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1154 
1155   // C = (C1+Fd, C2+Fo)
1156   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1157   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1158   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1159   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1160   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1161   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1162   PetscFunctionReturn(PETSC_SUCCESS);
1163 }
1164 
1165 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1166 {
1167   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1168   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1169   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1170   MPI_Comm        comm;
1171 
1172   PetscFunctionBegin;
1173   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1174   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1175   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1176   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1177   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1178 
1179   // Aot * (B's diag + B's off-diag)
1180   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1181   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1182 
1183   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1184   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1185 
1186   // Adt * (B's diag + B's off-diag)
1187   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1188   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1189 
1190   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1191 
1192   // C = (C1+Fd, C2+Fo)
1193   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1194   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1195   PetscFunctionReturn(PETSC_SUCCESS);
1196 }
1197 
1198 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1199 
1200   Input Parameters:
1201 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1202 .  A        - an MPIAIJKOKKOS matrix
1203 .  B        - an MPIAIJKOKKOS matrix
1204 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1205 */
1206 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1207 {
1208   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1209   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1210   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1211 
1212   PetscFunctionBegin;
1213   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1214   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1215   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1216   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1217 
1218   // TODO: add command line options to select spgemm algorithms
1219   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1220 
1221   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1222 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1223   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1224   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1225   #endif
1226 #endif
1227 
1228   mm->kh1.create_spgemm_handle(spgemm_alg);
1229   mm->kh2.create_spgemm_handle(spgemm_alg);
1230   mm->kh3.create_spgemm_handle(spgemm_alg);
1231   mm->kh4.create_spgemm_handle(spgemm_alg);
1232 
1233   // Bcast B's rows to form F, and overlap the communication
1234   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1235   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1236 
1237   // A's diag * (B's diag + B's off-diag)
1238   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1239   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1240   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1241   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1242   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1243   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1244 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1245   PetscCallCXX(sort_crs_matrix(mm->C1));
1246   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1247 #endif
1248 
1249   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1250 
1251   // A's off-diag * (F's diag + F's off-diag)
1252   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1253   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1254   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1255   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1256 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1257   PetscCallCXX(sort_crs_matrix(mm->C3));
1258   PetscCallCXX(sort_crs_matrix(mm->C4));
1259 #endif
1260 
1261   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1262   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1263   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1264   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1265   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1266 
1267   // C = (Cd, Co) = (C1+C3, C2+C4)
1268   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1269   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1270   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1271   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1272   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1273   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1274   PetscFunctionReturn(PETSC_SUCCESS);
1275 }
1276 
1277 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1278 {
1279   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1280   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1281   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1282 
1283   PetscFunctionBegin;
1284   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1285   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1286   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1287   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1288 
1289   // Bcast B's rows to form F, and overlap the communication
1290   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1291 
1292   // A's diag * (B's diag + B's off-diag)
1293   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1294   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1295 
1296   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1297 
1298   // A's off-diag * (F's diag + F's off-diag)
1299   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1300   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1301 
1302   // C = (Cd, Co) = (C1+C3, C2+C4)
1303   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1304   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1305   PetscFunctionReturn(PETSC_SUCCESS);
1306 }
1307 
1308 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1309 {
1310   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1311   Mat_Product                 *product;
1312   MatProductData_MPIAIJKokkos *pdata;
1313   MatProductType               ptype;
1314   Mat                          A, B;
1315 
1316   PetscFunctionBegin;
1317   MatCheckProduct(C, 1); // make sure C is a product
1318   product = C->product;
1319   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1320   ptype   = product->type;
1321   A       = product->A;
1322   B       = product->B;
1323 
1324   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1325   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1326   // we still do numeric.
1327   if (pdata->reusesym) { // numeric reuses results from symbolic
1328     pdata->reusesym = PETSC_FALSE;
1329     PetscFunctionReturn(PETSC_SUCCESS);
1330   }
1331 
1332   if (ptype == MATPRODUCT_AB) {
1333     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1334   } else if (ptype == MATPRODUCT_AtB) {
1335     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1336   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1337     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1338     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1339   }
1340 
1341   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1342   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1343   PetscFunctionReturn(PETSC_SUCCESS);
1344 }
1345 
1346 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1347 {
1348   Mat                          A, B;
1349   Mat_Product                 *product;
1350   MatProductType               ptype;
1351   MatProductData_MPIAIJKokkos *pdata;
1352   MatMatStruct                *mm = NULL;
1353   PetscInt                     m, n, M, N;
1354   Mat                          Cd, Co;
1355   MPI_Comm                     comm;
1356 
1357   PetscFunctionBegin;
1358   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1359   MatCheckProduct(C, 1);
1360   product = C->product;
1361   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1362   ptype = product->type;
1363   A     = product->A;
1364   B     = product->B;
1365 
1366   switch (ptype) {
1367   case MATPRODUCT_AB:
1368     m = A->rmap->n;
1369     n = B->cmap->n;
1370     M = A->rmap->N;
1371     N = B->cmap->N;
1372     break;
1373   case MATPRODUCT_AtB:
1374     m = A->cmap->n;
1375     n = B->cmap->n;
1376     M = A->cmap->N;
1377     N = B->cmap->N;
1378     break;
1379   case MATPRODUCT_PtAP:
1380     m = B->cmap->n;
1381     n = B->cmap->n;
1382     M = B->cmap->N;
1383     N = B->cmap->N;
1384     break; /* BtAB */
1385   default:
1386     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1387   }
1388 
1389   PetscCall(MatSetSizes(C, m, n, M, N));
1390   PetscCall(PetscLayoutSetUp(C->rmap));
1391   PetscCall(PetscLayoutSetUp(C->cmap));
1392   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1393 
1394   pdata           = new MatProductData_MPIAIJKokkos();
1395   pdata->reusesym = product->api_user;
1396 
1397   if (ptype == MATPRODUCT_AB) {
1398     auto mmAB = new MatMatStruct_AB();
1399     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1400     mm = pdata->mmAB = mmAB;
1401   } else if (ptype == MATPRODUCT_AtB) {
1402     auto mmAtB = new MatMatStruct_AtB();
1403     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1404     mm = pdata->mmAtB = mmAtB;
1405   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1406     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1407 
1408     auto mmAB = new MatMatStruct_AB();
1409     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1410     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1411     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1412     pdata->mmAB = mmAB;
1413 
1414     m = A->rmap->n; // Z's layout
1415     n = B->cmap->n;
1416     M = A->rmap->N;
1417     N = B->cmap->N;
1418     PetscCall(MatCreate(comm, &Z));
1419     PetscCall(MatSetSizes(Z, m, n, M, N));
1420     PetscCall(PetscLayoutSetUp(Z->rmap));
1421     PetscCall(PetscLayoutSetUp(Z->cmap));
1422     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1423     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1424 
1425     auto mmAtB = new MatMatStruct_AtB();
1426     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1427 
1428     pdata->Z = Z; // give ownership to pdata
1429     mm = pdata->mmAtB = mmAtB;
1430   }
1431 
1432   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1433   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1434   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1435 
1436   C->product->data       = pdata;
1437   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1438   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1439   PetscFunctionReturn(PETSC_SUCCESS);
1440 }
1441 
1442 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1443 {
1444   Mat_Product *product = mat->product;
1445   PetscBool    match   = PETSC_FALSE;
1446   PetscBool    usecpu  = PETSC_FALSE;
1447 
1448   PetscFunctionBegin;
1449   MatCheckProduct(mat, 1);
1450   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1451   if (match) { /* we can always fallback to the CPU if requested */
1452     switch (product->type) {
1453     case MATPRODUCT_AB:
1454       if (product->api_user) {
1455         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1456         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1457         PetscOptionsEnd();
1458       } else {
1459         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1460         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1461         PetscOptionsEnd();
1462       }
1463       break;
1464     case MATPRODUCT_AtB:
1465       if (product->api_user) {
1466         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1467         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1468         PetscOptionsEnd();
1469       } else {
1470         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1471         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1472         PetscOptionsEnd();
1473       }
1474       break;
1475     case MATPRODUCT_PtAP:
1476       if (product->api_user) {
1477         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1478         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1479         PetscOptionsEnd();
1480       } else {
1481         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1482         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1483         PetscOptionsEnd();
1484       }
1485       break;
1486     default:
1487       break;
1488     }
1489     match = (PetscBool)!usecpu;
1490   }
1491   if (match) {
1492     switch (product->type) {
1493     case MATPRODUCT_AB:
1494     case MATPRODUCT_AtB:
1495     case MATPRODUCT_PtAP:
1496       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1497       break;
1498     default:
1499       break;
1500     }
1501   }
1502   /* fallback to MPIAIJ ops */
1503   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1504   PetscFunctionReturn(PETSC_SUCCESS);
1505 }
1506 
1507 // Mirror of MatCOOStruct_MPIAIJ on device
1508 struct MatCOOStruct_MPIAIJKokkos {
1509   PetscCount           n;
1510   PetscSF              sf;
1511   PetscCount           Annz, Bnnz;
1512   PetscCount           Annz2, Bnnz2;
1513   PetscCountKokkosView Ajmap1, Aperm1;
1514   PetscCountKokkosView Bjmap1, Bperm1;
1515   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1516   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1517   PetscCountKokkosView Cperm1;
1518   MatScalarKokkosView  sendbuf, recvbuf;
1519 
1520   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1521   {
1522     auto &exec = PetscGetKokkosExecutionSpace();
1523 
1524     n       = coo_h->n;
1525     sf      = coo_h->sf;
1526     Annz    = coo_h->Annz;
1527     Bnnz    = coo_h->Bnnz;
1528     Annz2   = coo_h->Annz2;
1529     Bnnz2   = coo_h->Bnnz2;
1530     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1531     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1532     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1533     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1534     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1535     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1536     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1537     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1538     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1539     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1540     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1541     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1542     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1543     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1544   }
1545 
1546   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1547 };
1548 
1549 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1550 {
1551   PetscFunctionBegin;
1552   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1553   PetscFunctionReturn(PETSC_SUCCESS);
1554 }
1555 
1556 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1557 {
1558   PetscContainer             container_h, container_d;
1559   MatCOOStruct_MPIAIJ       *coo_h;
1560   MatCOOStruct_MPIAIJKokkos *coo_d;
1561 
1562   PetscFunctionBegin;
1563   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1564   mat->preallocated = PETSC_TRUE;
1565   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1566   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1567   PetscCall(MatZeroEntries(mat));
1568 
1569   // Copy the COO struct to device
1570   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1571   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1572   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1573 
1574   // Put the COO struct in a container and then attach that to the matrix
1575   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1576   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1577   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1578   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1579   PetscCall(PetscContainerDestroy(&container_d));
1580   PetscFunctionReturn(PETSC_SUCCESS);
1581 }
1582 
1583 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1584 {
1585   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1586   Mat                            A = mpiaij->A, B = mpiaij->B;
1587   MatScalarKokkosView            Aa, Ba;
1588   MatScalarKokkosView            v1;
1589   PetscMemType                   memtype;
1590   PetscContainer                 container;
1591   MatCOOStruct_MPIAIJKokkos     *coo;
1592   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
1593 
1594   PetscFunctionBegin;
1595   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1596   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1597 
1598   const auto &n      = coo->n;
1599   const auto &Annz   = coo->Annz;
1600   const auto &Annz2  = coo->Annz2;
1601   const auto &Bnnz   = coo->Bnnz;
1602   const auto &Bnnz2  = coo->Bnnz2;
1603   const auto &vsend  = coo->sendbuf;
1604   const auto &v2     = coo->recvbuf;
1605   const auto &Ajmap1 = coo->Ajmap1;
1606   const auto &Ajmap2 = coo->Ajmap2;
1607   const auto &Aimap2 = coo->Aimap2;
1608   const auto &Bjmap1 = coo->Bjmap1;
1609   const auto &Bjmap2 = coo->Bjmap2;
1610   const auto &Bimap2 = coo->Bimap2;
1611   const auto &Aperm1 = coo->Aperm1;
1612   const auto &Aperm2 = coo->Aperm2;
1613   const auto &Bperm1 = coo->Bperm1;
1614   const auto &Bperm2 = coo->Bperm2;
1615   const auto &Cperm1 = coo->Cperm1;
1616 
1617   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1618   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1619     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1620   } else {
1621     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1622   }
1623 
1624   if (imode == INSERT_VALUES) {
1625     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1626     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1627   } else {
1628     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1629     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1630   }
1631 
1632   PetscCall(PetscLogGpuTimeBegin());
1633   /* Pack entries to be sent to remote */
1634   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1635 
1636   /* Send remote entries to their owner and overlap the communication with local computation */
1637   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1638   /* Add local entries to A and B in one kernel */
1639   Kokkos::parallel_for(
1640     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1641       PetscScalar sum = 0.0;
1642       if (i < Annz) {
1643         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1644         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1645       } else {
1646         i -= Annz;
1647         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1648         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1649       }
1650     });
1651   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1652 
1653   /* Add received remote entries to A and B in one kernel */
1654   Kokkos::parallel_for(
1655     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1656       if (i < Annz2) {
1657         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1658       } else {
1659         i -= Annz2;
1660         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1661       }
1662     });
1663   PetscCall(PetscLogGpuTimeEnd());
1664 
1665   if (imode == INSERT_VALUES) {
1666     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1667     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1668   } else {
1669     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1670     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1671   }
1672   PetscFunctionReturn(PETSC_SUCCESS);
1673 }
1674 
1675 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1676 {
1677   PetscFunctionBegin;
1678   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1679   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1680   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1681   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1682   PetscCall(MatDestroy_MPIAIJ(A));
1683   PetscFunctionReturn(PETSC_SUCCESS);
1684 }
1685 
1686 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1687 {
1688   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1689   PetscBool   congruent;
1690 
1691   PetscFunctionBegin;
1692   PetscCall(MatHasCongruentLayouts(A, &congruent));
1693   if (congruent) { // square matrix and the diagonals are solely in the diag block
1694     PetscCall(MatShift(mpiaij->A, a));
1695   } else { // too hard, use the general version
1696     PetscCall(MatShift_Basic(A, a));
1697   }
1698   PetscFunctionReturn(PETSC_SUCCESS);
1699 }
1700 
1701 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1702 {
1703   PetscFunctionBegin;
1704   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1705   B->ops->mult                  = MatMult_MPIAIJKokkos;
1706   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1707   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1708   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1709   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1710   B->ops->shift                 = MatShift_MPIAIJKokkos;
1711 
1712   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1713   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1714   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1715   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1716   PetscFunctionReturn(PETSC_SUCCESS);
1717 }
1718 
1719 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1720 {
1721   Mat         B;
1722   Mat_MPIAIJ *a;
1723 
1724   PetscFunctionBegin;
1725   if (reuse == MAT_INITIAL_MATRIX) {
1726     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1727   } else if (reuse == MAT_REUSE_MATRIX) {
1728     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1729   }
1730   B = *newmat;
1731 
1732   B->boundtocpu = PETSC_FALSE;
1733   PetscCall(PetscFree(B->defaultvectype));
1734   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1735   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1736 
1737   a = static_cast<Mat_MPIAIJ *>(A->data);
1738   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1739   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1740   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1741   PetscCall(MatSetOps_MPIAIJKokkos(B));
1742   PetscFunctionReturn(PETSC_SUCCESS);
1743 }
1744 
1745 /*MC
1746    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1747 
1748    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1749 
1750    Options Database Key:
1751 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1752 
1753   Level: beginner
1754 
1755 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1756 M*/
1757 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1758 {
1759   PetscFunctionBegin;
1760   PetscCall(PetscKokkosInitializeCheck());
1761   PetscCall(MatCreate_MPIAIJ(A));
1762   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1763   PetscFunctionReturn(PETSC_SUCCESS);
1764 }
1765 
1766 /*@C
1767   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
1768   (the default parallel PETSc format).  This matrix will ultimately pushed down
1769   to Kokkos for calculations.
1770 
1771   Collective
1772 
1773   Input Parameters:
1774 + comm  - MPI communicator, set to `PETSC_COMM_SELF`
1775 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1776            This value should be the same as the local size used in creating the
1777            y vector for the matrix-vector product y = Ax.
1778 . n     - This value should be the same as the local size used in creating the
1779        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1780        calculated if N is given) For square matrices n is almost always `m`.
1781 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1782 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1783 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1784            (same value is used for all local rows)
1785 . d_nnz - array containing the number of nonzeros in the various rows of the
1786            DIAGONAL portion of the local submatrix (possibly different for each row)
1787            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1788            The size of this array is equal to the number of local rows, i.e `m`.
1789            For matrices you plan to factor you must leave room for the diagonal entry and
1790            put in the entry even if it is zero.
1791 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1792            submatrix (same value is used for all local rows).
1793 - o_nnz - array containing the number of nonzeros in the various rows of the
1794            OFF-DIAGONAL portion of the local submatrix (possibly different for
1795            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1796            structure. The size of this array is equal to the number
1797            of local rows, i.e `m`.
1798 
1799   Output Parameter:
1800 . A - the matrix
1801 
1802   Level: intermediate
1803 
1804   Notes:
1805   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1806   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1807   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1808 
1809   The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
1810   storage.  That is, the stored row and column indices can begin at
1811   either one (as in Fortran) or zero.
1812 
1813 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1814           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
1815 @*/
1816 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1817 {
1818   PetscMPIInt size;
1819 
1820   PetscFunctionBegin;
1821   PetscCall(MatCreate(comm, A));
1822   PetscCall(MatSetSizes(*A, m, n, M, N));
1823   PetscCallMPI(MPI_Comm_size(comm, &size));
1824   if (size > 1) {
1825     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1826     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1827   } else {
1828     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1829     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1830   }
1831   PetscFunctionReturn(PETSC_SUCCESS);
1832 }
1833