xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 83d0d507e8eaf8d844b49c5dbd0d8d7c8cefa37b)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <petsc/private/kokkosimpl.hpp>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <KokkosSparse_spadd.hpp>
9 #include <KokkosSparse_spgemm.hpp>
10 
11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12 {
13   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19    */
20   if (mode == MAT_FINAL_ASSEMBLY) {
21     PetscScalarKokkosView v;
22 
23     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
24     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
26     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
28   }
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33 {
34   Mat_MPIAIJ *mpiaij;
35 
36   PetscFunctionBegin;
37   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
38   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
39   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
40   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
41   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
46 {
47   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
48   PetscInt    nt;
49 
50   PetscFunctionBegin;
51   PetscCall(VecGetLocalSize(xx, &nt));
52   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
53   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
54   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
55   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
56   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
57   PetscFunctionReturn(PETSC_SUCCESS);
58 }
59 
60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
61 {
62   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
63   PetscInt    nt;
64 
65   PetscFunctionBegin;
66   PetscCall(VecGetLocalSize(xx, &nt));
67   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
68   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
69   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
70   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
71   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
76 {
77   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
78   PetscInt    nt;
79 
80   PetscFunctionBegin;
81   PetscCall(VecGetLocalSize(xx, &nt));
82   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
83   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
84   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
85   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
86   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
91    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
92    C still uses local column ids. Their corresponding global column ids are returned in glob.
93 */
94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
95 {
96   Mat             Ad, Ao;
97   const PetscInt *cmap;
98 
99   PetscFunctionBegin;
100   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
101   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
102   if (glob) {
103     PetscInt cst, i, dn, on, *gidx;
104     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
105     PetscCall(MatGetLocalSize(Ao, NULL, &on));
106     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
107     PetscCall(PetscMalloc1(dn + on, &gidx));
108     for (i = 0; i < dn; i++) gidx[i] = cst + i;
109     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
110     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
111   }
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
115 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
116 struct MatMatStruct {
117   PetscInt            n, *garray;     // C's garray and its size.
118   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
119   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
120   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
121   PetscIntKokkosView  E_NzLeft;
122   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
123   MatScalarKokkosView rootBuf, leafBuf;
124   KokkosCsrMatrix     Fd, Fo; // F in split form
125 
126   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
127   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
128   KernelHandle kh3; // compute C3
129   KernelHandle kh4; // compute C4
130 
131   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
132   PetscInt E_VectorLength;
133   PetscInt E_RowsPerTeam;
134   PetscInt F_TeamSize;
135   PetscInt F_VectorLength;
136   PetscInt F_RowsPerTeam;
137 
138   ~MatMatStruct()
139   {
140     PetscFunctionBegin;
141     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
142     PetscFunctionReturnVoid();
143   }
144 };
145 
146 struct MatMatStruct_AB : public MatMatStruct {
147   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
148   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
149   PetscIntKokkosView rowoffset;
150 };
151 
152 struct MatMatStruct_AtB : public MatMatStruct {
153   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
154   MatColIdxKokkosView Fdjperm;
155   MatColIdxKokkosView Fojmap;
156   MatColIdxKokkosView Fojperm;
157 };
158 
159 struct MatProductData_MPIAIJKokkos {
160   MatMatStruct_AB  *mmAB     = nullptr;
161   MatMatStruct_AtB *mmAtB    = nullptr;
162   PetscBool         reusesym = PETSC_FALSE;
163   Mat               Z        = nullptr; // store Z=AB in computing BtAB
164 
165   ~MatProductData_MPIAIJKokkos()
166   {
167     delete mmAB;
168     delete mmAtB;
169     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
170   }
171 };
172 
173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
174 {
175   PetscFunctionBegin;
176   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
177   PetscFunctionReturn(PETSC_SUCCESS);
178 }
179 
180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
181    It is similar to MatCreateMPIAIJWithSplitArrays.
182 
183   Input Parameters:
184 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
185 .  A     - the diag matrix using local col ids
186 -  B     - the offdiag matrix using global col ids
187 
188   Output Parameter:
189 .  mat   - the updated MATMPIAIJKOKKOS matrix
190 */
191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
192 {
193   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
194   PetscInt    m, n, M, N, Am, An, Bm, Bn;
195 
196   PetscFunctionBegin;
197   PetscCall(MatGetSize(mat, &M, &N));
198   PetscCall(MatGetLocalSize(mat, &m, &n));
199   PetscCall(MatGetLocalSize(A, &Am, &An));
200   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
201 
202   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
203   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
204   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
205   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
206   mpiaij->A      = A;
207   mpiaij->B      = B;
208   mpiaij->garray = garray;
209 
210   mat->preallocated     = PETSC_TRUE;
211   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
212 
213   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
214   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
215   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
216     also gets mpiaij->B compacted, with its col ids and size reduced
217   */
218   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
219   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
220   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
221   PetscFunctionReturn(PETSC_SUCCESS);
222 }
223 
224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
226 template <class ExecutionSpace>
227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
228 {
229   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
230 
231   PetscFunctionBegin;
232   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
233 
234   if (nnz_per_row < 1) nnz_per_row = 1;
235 
236   int max_vector_length = teamPolicy.vector_length_max();
237 
238   if (vector_length < 1) {
239     vector_length = 1;
240     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
241   }
242 
243   // Determine rows per thread
244   if (rows_per_thread < 1) {
245     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
246     else {
247       if (nnz_per_row < 20 && nnz > 5000000) {
248         rows_per_thread = 256;
249       } else rows_per_thread = 64;
250     }
251   }
252 
253   if (team_size < 1) {
254     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
255       team_size = 256 / vector_length;
256     } else {
257       team_size = 1;
258     }
259   }
260 
261   rows_per_team = rows_per_thread * team_size;
262 
263   if (rows_per_team < 0) {
264     PetscInt nnz_per_team = 4096;
265     PetscInt conc         = ExecutionSpace().concurrency();
266     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
267     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
268   }
269   PetscFunctionReturn(PETSC_SUCCESS);
270 }
271 
272 /*
273   Reduce two sets of global indices into local ones
274 
275   Input Parameters:
276 +  n1          - size of garray1[], the first set
277 .  garray1[n1] - a sorted global index array (without duplicates)
278 .  m           - size of indices[], the second set
279 -  indices[m]  - an unsorted global index array (might have duplicates), which will be updated on output into local ones
280 
281   Output Parameters:
282 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
283 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
284 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
285 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
286 
287    Example, say
288     n1         = 5
289     garray1[5] = {1, 4, 7, 8, 10}
290     m          = 4
291     indices[4] = {2, 4, 8, 9}
292 
293    Combining them together, we have 7 global indices in garray2[]
294     n2         = 7
295     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
296 
297    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
298     map[5] = {0, 2, 3, 4, 6}
299 
300    On output, indices[] is updated with local indices
301     indices[4] = {1, 2, 4, 5}
302 */
303 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
304 {
305   PetscHMapI    g2l = nullptr;
306   PetscHashIter iter;
307   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
308   PetscInt      n2, *garray2;
309 
310   PetscFunctionBegin;
311   tot = 0;
312   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
313   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
314     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
315     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
316   }
317 
318   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
319     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
320     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
321   }
322 
323   // Pull out (unique) globals in the hash table and put them in garray2[]
324   n2 = tot;
325   PetscCall(PetscMalloc1(n2, &garray2));
326   tot = 0;
327   PetscHashIterBegin(g2l, iter);
328   while (!PetscHashIterAtEnd(g2l, iter)) {
329     PetscHashIterGetKey(g2l, iter, key);
330     PetscHashIterNext(g2l, iter);
331     garray2[tot++] = key;
332   }
333 
334   // Sort garray2[] and then map them to local indices starting from 0
335   PetscCall(PetscSortInt(n2, garray2));
336   PetscCall(PetscHMapIClear(g2l));
337   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
338 
339   // Rewrite indices[] with local indices
340   for (PetscInt i = 0; i < m; i++) {
341     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
342     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
343     indices[i] = val;
344   }
345   // Record the map that maps garray1[i] to garray2[map[i]]
346   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
347   PetscCall(PetscHMapIDestroy(&g2l));
348   *n2_      = n2;
349   *garray2_ = garray2;
350   PetscFunctionReturn(PETSC_SUCCESS);
351 }
352 
353 /*
354   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
355 
356   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
357 
358   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
359   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
360 
361   Input Parameters:
362 +  comm       - MPI communicator of E
363 .  A          - diag block of E, using local column indices
364 .  B          - off-diag block of E, using local column indices
365 .  cstart      - (global) start column of Ed
366 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
367 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
368 .  ownerSF     - the SF specifies ownership (root) of rows in E
369 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
370 -  mm          - to stash intermediate data structures for reuse
371 
372   Output Parameters:
373 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
374 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
375 
376   Notes:
377   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
378 
379  */
380 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
381 {
382   PetscFunctionBegin;
383   if (reuse == MAT_INITIAL_MATRIX) {
384     PetscInt Em = A.numRows(), Fm;
385     PetscInt n1 = B.numCols();
386 
387     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
388 
389     // Do the analysis on host
390     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
391     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
392     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
393     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
394     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
395     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
396 
397     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
398     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
399     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
400     for (PetscInt i = 0; i < Em; i++) {
401       const PetscInt *first, *last, *it;
402       PetscInt        count, step;
403       // std::lower_bound(first,last,cstart), but need to use global column indices
404       first = Bj + Bi[i];
405       last  = Bj + Bi[i + 1];
406       count = last - first;
407       while (count > 0) {
408         it   = first;
409         step = count / 2;
410         it += step;
411         if (garray1[*it] < cstart) { // map local to global
412           first = ++it;
413           count -= step + 1;
414         } else count = step;
415       }
416       E_NzLeft[i] = first - (Bj + Bi[i]);
417       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
418     }
419 
420     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
421     const PetscMPIInt *iranks, *ranks;
422     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
423     PetscMPIInt        niranks, nranks;
424     MPI_Request       *reqs;
425     PetscMPIInt        tag;
426     PetscSF            reduceSF;
427     PetscInt          *sdisp, *rdisp;
428 
429     PetscCall(PetscCommGetNewTag(comm, &tag));
430     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
431     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
432 
433     // Find out length of each row I will receive. Even for the same row index, when they are from
434     // different senders, they might have different lengths (and sparsity patterns)
435     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
436     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
437 
438     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
439 
440     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
441     recvRowLen[0] = 0; // since we will make it in CSR format later
442     recvRowLen++;      // advance the pointer now
443     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
444     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
445     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
446 
447     // Build the real PetscSF for reducing E rows (buffer to buffer)
448     rdisp[0] = 0;
449     for (PetscInt i = 0; i < niranks; i++) {
450       rdisp[i + 1] = rdisp[i];
451       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
452     }
453     recvRowLen--; // put it back into csr format
454     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
455 
456     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
457     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
458     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
459 
460     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
461     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
462     PetscSFNode *iremote;
463 
464     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
465     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
466     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
467 
468     for (PetscInt i = 0; i < nranks; i++) {
469       PetscInt count = 0;
470       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
471       for (PetscInt j = 0; j < count; j++) {
472         iremote[nleaves + j].rank  = ranks[i];
473         iremote[nleaves + j].index = sdisp[i] + j;
474       }
475       nleaves += count;
476     }
477     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
478 
479     PetscCall(PetscSFCreate(comm, &reduceSF));
480     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
481 
482     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
483     PetscInt *sendCol, *recvCol;
484     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
485     for (PetscInt k = 0; k < roffset[nranks]; k++) {
486       PetscInt  i      = rmine[k]; // row to be copied
487       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
488       PetscInt  nzLeft = E_NzLeft[i];
489       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
490       for (PetscInt j = 0; j < alen + blen; j++) {
491         if (j < nzLeft) {
492           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
493         } else if (j < nzLeft + alen) {
494           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
495         } else {
496           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
497         }
498       }
499     }
500     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
501     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
502 
503     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
504     PetscInt *recvRowPerm, *recvColSorted;
505     PetscInt *recvNzPerm, *recvNzPermSorted;
506     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
507 
508     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
509     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
510     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
511 
512     // i[] array, nz are always easiest to compute
513     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
514     MatRowMapType          *Fdi, *Foi;
515     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
516     PetscInt                iter;
517 
518     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
519     Kokkos::deep_copy(Foi_h, 0);
520     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
521     Foi  = Foi_h.data() + 1;
522     iter = 0;
523     while (iter < recvRowCnt) { // iter over received rows
524       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
525       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
526 
527       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
528 
529       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
530       PetscInt  nz    = 0; // nz (with dups) in the current row
531       PetscInt *jbuf  = recvColSorted + FnzDups;
532       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
533       PetscInt *jbuf2 = jbuf; // temp pointers
534       PetscInt *pbuf2 = pbuf;
535       for (PetscInt d = 0; d < dupRows; d++) {
536         PetscInt i   = recvRowPerm[iter + d];
537         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
538         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
539         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
540         jbuf2 += len;
541         pbuf2 += len;
542         nz += len;
543       }
544       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
545 
546       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
547       PetscInt cur = 0;
548       while (cur < nz) {
549         PetscInt curColIdx = jbuf[cur];
550         PetscInt dups      = 1;
551 
552         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
553         if (curColIdx >= cstart && curColIdx < cend) {
554           Fdi[curRowIdx]++;
555           FdnzDups += dups;
556         } else {
557           Foi[curRowIdx]++;
558           FonzDups += dups;
559         }
560         cur += dups;
561       }
562 
563       FnzDups += nz;
564       iter += dupRows; // Move to next unique row
565     }
566 
567     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
568     Foi = Foi_h.data();
569     for (PetscInt i = 0; i < Fm; i++) {
570       Fdi[i + 1] += Fdi[i];
571       Foi[i + 1] += Foi[i];
572     }
573     Fdnz = Fdi[Fm];
574     Fonz = Foi[Fm];
575     PetscCall(PetscFree2(sendCol, recvCol));
576 
577     // Allocate j, jmap, jperm for Fd and Fo
578     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
579     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
580     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
581     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
582     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
583     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
584 
585     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
586     Fdjmap[0] = 0;
587     Fojmap[0] = 0;
588     FnzDups   = 0;
589     Fdnz      = 0;
590     Fonz      = 0;
591     iter      = 0; // iter over received rows
592     while (iter < recvRowCnt) {
593       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
594       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
595       PetscInt nz        = 0;                           // nz (with dups) in the current row
596 
597       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
598       for (PetscInt d = 0; d < dupRows; d++) {
599         PetscInt i = recvRowPerm[iter + d];
600         nz += recvRowLen[i + 1] - recvRowLen[i];
601       }
602 
603       PetscInt *jbuf = recvColSorted + FnzDups;
604       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
605       PetscInt cur = 0;
606       while (cur < nz) {
607         PetscInt curColIdx = jbuf[cur];
608         PetscInt dups      = 1;
609 
610         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
611         if (curColIdx >= cstart && curColIdx < cend) {
612           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
613           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
614           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
615           FdnzDups += dups;
616           Fdnz++;
617         } else {
618           Foj[Fonz]        = curColIdx; // in global
619           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
620           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
621           FonzDups += dups;
622           Fonz++;
623         }
624         cur += dups;
625         FnzDups += dups;
626       }
627       iter += dupRows; // Move to next unique row
628     }
629     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
630     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
631 
632     // Combine global column indices in garray1[] and Foj[]
633     PetscInt n2, *garray2;
634 
635     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
636     mm->sf       = reduceSF;
637     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
638     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
639     mm->garray   = garray2; // give ownership, so no free
640     mm->n        = n2;
641     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
642     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
643     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
644     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
645     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
646 
647     // Output Fd and Fo in KokkosCsrMatrix format
648     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
649     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
650     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
651     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
652     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
653     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
654 
655     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
656     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
657 
658     // Compute kernel launch parameters in merging E
659     PetscInt teamSize, vectorLength, rowsPerTeam;
660 
661     teamSize = vectorLength = rowsPerTeam = -1;
662     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
663     mm->E_TeamSize     = teamSize;
664     mm->E_VectorLength = vectorLength;
665     mm->E_RowsPerTeam  = rowsPerTeam;
666   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
667 
668   // Handy aliases
669   auto       &Aa           = A.values;
670   auto       &Ba           = B.values;
671   const auto &Ai           = A.graph.row_map;
672   const auto &Bi           = B.graph.row_map;
673   const auto &E_NzLeft     = mm->E_NzLeft;
674   auto       &leafBuf      = mm->leafBuf;
675   auto       &rootBuf      = mm->rootBuf;
676   PetscSF     reduceSF     = mm->sf;
677   PetscInt    Em           = A.numRows();
678   PetscInt    teamSize     = mm->E_TeamSize;
679   PetscInt    vectorLength = mm->E_VectorLength;
680   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
681   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
682 
683   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
684   PetscCallCXX(Kokkos::parallel_for(
685     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
686       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
687         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
688         if (i < Em) {
689           PetscInt disp   = Ai(i) + Bi(i);
690           PetscInt alen   = Ai(i + 1) - Ai(i);
691           PetscInt blen   = Bi(i + 1) - Bi(i);
692           PetscInt nzleft = E_NzLeft(i);
693 
694           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
695             MatScalar &val = leafBuf(disp + j);
696             if (j < nzleft) { // B left
697               val = Ba(Bi(i) + j);
698             } else if (j < nzleft + alen) { // diag A
699               val = Aa(Ai(i) + j - nzleft);
700             } else { // B right
701               val = Ba(Bi(i) + j - alen);
702             }
703           });
704         }
705       });
706     }));
707   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
708   PetscFunctionReturn(PETSC_SUCCESS);
709 }
710 
711 // To finish MatMPIAIJKokkosReduce.
    //
    // Completes the PetscSFReduce on mm->sf (mm->leafBuf -> mm->rootBuf, with MPI_REPLACE) and then accumulates the
    // received values into the value arrays of mm->Fd and mm->Fo: for the i-th nonzero of Fd (resp. Fo), the entries
    // rootBuf(Fdjperm(k)) (resp. rootBuf(Fojperm(k))) with k in [Fdjmap(i), Fdjmap(i+1)) (resp. Fojmap) are summed.
    //
    // Only mm is used in this routine; comm, A, B, cstart, cend, garray1, ownerSF, reuse and map are not referenced here.
712 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
713 {
      // Handy aliases to the plans and buffers stashed in mm
714   auto       &leafBuf  = mm->leafBuf;
715   auto       &rootBuf  = mm->rootBuf;
716   auto       &Fda      = mm->Fd.values;
717   const auto &Fdjmap   = mm->Fdjmap;
718   const auto &Fdjperm  = mm->Fdjperm;
719   auto        Fdnz     = mm->Fd.nnz();
720   auto       &Foa      = mm->Fo.values;
721   const auto &Fojmap   = mm->Fojmap;
722   const auto &Fojperm  = mm->Fojperm;
723   auto        Fonz     = mm->Fo.nnz();
724   PetscSF     reduceSF = mm->sf;
725 
726   PetscFunctionBegin;
      // Finish the communication; afterwards rootBuf holds the values received from the leaves
727   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
728 
729   // Reduce data in rootBuf to Fd and Fo
      // One work item per nonzero of Fd; Fdjmap is an offset array into Fdjperm, which lists the rootBuf positions to add up
730   PetscCallCXX(Kokkos::parallel_for(
731     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
732       PetscScalar sum = 0.0;
733       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
734       Fda(i) = sum;
735     }));
736 
      // Same as above, but for the nonzeros of Fo
737   PetscCallCXX(Kokkos::parallel_for(
738     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
739       PetscScalar sum = 0.0;
740       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
741       Foa(i) = sum;
742     }));
743   PetscFunctionReturn(PETSC_SUCCESS);
744 }
745 
746 /*
747   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
748 
749   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
750   device and involves various index mapping.
751 
752   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
753   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
754   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
755   F has the same column layout as E.
756 
757   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
758   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
759   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
760   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
761   column indices in Fo and update Fo with local indices.
762 
763    Input Parameters:
764 +   E       - the MPIAIJKOKKOS matrix
765 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
766 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
767 -   mm      - to stash matproduct intermediate data structures
768 
769     Output Parameters:
770 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
771 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
772 
773     Notes:
774     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
775     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
776 */
777 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
778 {
779   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
780   Mat               A = empi->A, B = empi->B; // diag and off-diag
781   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
782   PetscInt          Em = E->rmap->n; // #local rows
783   MPI_Comm          comm;
784 
785   PetscFunctionBegin;
786   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
787   if (reuse == MAT_INITIAL_MATRIX) {
788     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
789     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
790     const PetscInt *garray1 = empi->garray; // its size is n1
791     PetscInt        cstart, cend;
792     PetscSF         bcastSF;
793 
794     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
795 
796     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
797     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
798     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
799     for (PetscInt i = 0; i < Em; i++) {
800       const PetscInt *first, *last, *it;
801       PetscInt        count, step;
802       // std::lower_bound(first,last,cstart), but need to use global column indices
803       first = Bj + Bi[i];
804       last  = Bj + Bi[i + 1];
805       count = last - first;
806       while (count > 0) {
807         it   = first;
808         step = count / 2;
809         it += step;
810         if (empi->garray[*it] < cstart) { // map local to global
811           first = ++it;
812           count -= step + 1;
813         } else count = step;
814       }
815       E_NzLeft[i] = first - (Bj + Bi[i]);
816       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
817     }
818 
819     // Compute row pointer Fi of F
820     PetscInt *Fi, Fm, Fnz;
821     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
822     PetscCall(PetscMalloc1(Fm + 1, &Fi));
823     Fi[0] = 0;
824     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
825     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
826     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
827     Fnz = Fi[Fm];
828 
829     // Build the real PetscSF for bcasting E rows (buffer to buffer)
830     const PetscMPIInt *iranks, *ranks;
831     const PetscInt    *ioffset, *irootloc, *roffset;
832     PetscMPIInt        niranks, nranks;
833     PetscInt          *sdisp, *rdisp;
834     MPI_Request       *reqs;
835     PetscMPIInt        tag;
836 
837     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
838     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
839     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
840 
841     sdisp[0] = 0; // send displacement
842     for (PetscInt i = 0; i < niranks; i++) {
843       sdisp[i + 1] = sdisp[i];
844       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
845         PetscInt r = irootloc[j]; // row to be sent
846         sdisp[i + 1] += E_RowLen[r];
847       }
848     }
849 
850     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
851     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
852     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
853     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
854 
855     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
856     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
857     PetscSFNode *iremote;                  // give ownership to bcastSF
858     PetscCall(PetscMalloc1(nleaves, &iremote));
859     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
860       PetscInt k = 0;
861       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
862         iremote[j].rank  = ranks[i];
863         iremote[j].index = rdisp[i] + k; // their root location
864         k++;
865       }
866     }
867     PetscCall(PetscSFCreate(comm, &bcastSF));
868     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
869     PetscCall(PetscFree3(sdisp, rdisp, reqs));
870 
871     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
872     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
873     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
874     rowoffset[0]                     = 0;
875     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
876 
877     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
878     PetscInt *jbuf, *Fj;
879     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
880     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
881       PetscInt  i      = irootloc[k]; // row to be copied
882       PetscInt *buf    = &jbuf[rowoffset[k]];
883       PetscInt  nzLeft = E_NzLeft[i];
884       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
885       for (PetscInt j = 0; j < alen + blen; j++) {
886         if (j < nzLeft) {
887           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
888         } else if (j < nzLeft + alen) {
889           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
890         } else {
891           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
892         }
893       }
894     }
895     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
896     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
897 
898     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
899     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
900     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
901     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
902     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
903 
904     Fdi[0] = Foi[0] = 0;
905     for (PetscInt i = 0; i < Fm; i++) {
906       PetscInt *first, *last, *lb1, *lb2;
907       // cut the row into: Left, [cstart, cend), Right
908       first       = Fj + Fi[i];
909       last        = Fj + Fi[i + 1];
910       lb1         = std::lower_bound(first, last, cstart);
911       F_NzLeft[i] = lb1 - first;
912       lb2         = std::lower_bound(first, last, cend);
913       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
914       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
915     }
916     for (PetscInt i = 0; i < Fm; i++) {
917       Fdi[i + 1] += Fdi[i];
918       Foi[i + 1] += Foi[i];
919     }
920 
921     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
922     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
923     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
924     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
925 
926     for (PetscInt i = 0; i < Fm; i++) {
927       PetscInt nzLeft = F_NzLeft[i];
928       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
929       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
930         gid = Fj[Fi[i] + j];
931         if (j < nzLeft) { // left, in global
932           Foj[Foi[i] + j] = gid;
933         } else if (j < nzLeft + len) { // diag, in local
934           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
935         } else { // right, in global
936           Foj[Foi[i] + j - len] = gid;
937         }
938       }
939     }
940     PetscCall(PetscFree2(jbuf, Fj));
941     PetscCall(PetscFree(Fi));
942 
943     // Reduce global indices in Foj[] and garray1[] into local ones
944     PetscInt n2, *garray2;
945     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
946 
947     // Record the plans built above, for reuse
948     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
949     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
950     Kokkos::deep_copy(irootloc_h, tmp);
951     mm->sf        = bcastSF;
952     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
953     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
954     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
955     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
956     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
957     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
958     mm->garray    = garray2;
959     mm->n         = n2;
960 
961     // Output Fd and Fo in KokkosCsrMatrix format
962     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
963     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
964     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
965     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
966     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
967 
968     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
969     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
970 
971     // Compute kernel launch parameters in merging E or splitting F
972     PetscInt teamSize, vectorLength, rowsPerTeam;
973 
974     teamSize = vectorLength = rowsPerTeam = -1;
975     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
976     mm->E_TeamSize     = teamSize;
977     mm->E_VectorLength = vectorLength;
978     mm->E_RowsPerTeam  = rowsPerTeam;
979 
980     teamSize = vectorLength = rowsPerTeam = -1;
981     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
982     mm->F_TeamSize     = teamSize;
983     mm->F_VectorLength = vectorLength;
984     mm->F_RowsPerTeam  = rowsPerTeam;
985   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
986 
987   // Sync E's value to device
988   akok->a_dual.sync_device();
989   bkok->a_dual.sync_device();
990 
991   // Handy aliases
992   const auto &Aa = akok->a_dual.view_device();
993   const auto &Ba = bkok->a_dual.view_device();
994   const auto &Ai = akok->i_dual.view_device();
995   const auto &Bi = bkok->i_dual.view_device();
996 
997   // Fetch the plans
998   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
999   PetscSF             &bcastSF   = mm->sf;
1000   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1001   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1002   PetscIntKokkosView  &irootloc  = mm->irootloc;
1003   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1004 
1005   PetscInt teamSize     = mm->E_TeamSize;
1006   PetscInt vectorLength = mm->E_VectorLength;
1007   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1008   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1009 
1010   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1011   PetscCallCXX(Kokkos::parallel_for(
1012     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1013       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1014         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1015         if (r < irootloc.extent(0)) {
1016           PetscInt i      = irootloc(r); // row i of E
1017           PetscInt disp   = rowoffset(r);
1018           PetscInt alen   = Ai(i + 1) - Ai(i);
1019           PetscInt blen   = Bi(i + 1) - Bi(i);
1020           PetscInt nzleft = E_NzLeft(i);
1021 
1022           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1023             if (j < nzleft) { // B left
1024               rootBuf(disp + j) = Ba(Bi(i) + j);
1025             } else if (j < nzleft + alen) { // diag A
1026               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1027             } else { // B right
1028               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1029             }
1030           });
1031         }
1032       });
1033     }));
1034   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1035   PetscFunctionReturn(PETSC_SUCCESS);
1036 }
1037 
1038 // To finish MatMPIAIJKokkosBcast.
     //
     // Completes the PetscSFBcast on mm->sf (mm->rootBuf -> mm->leafBuf) and scatters the received values, row by row, into
     // the value arrays of mm->Fd and mm->Fo, using mm->F_NzLeft and the row maps of Fd/Fo to tell which entries of a row
     // belong to the left part of Fo, to Fd, or to the right part of Fo.
     //
     // Only mm is used in this routine; E, ownerSF, reuse and map are not referenced here.
1039 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1040 {
1041   PetscFunctionBegin;
       // Handy aliases to the split matrix F = (Fd, Fo) stashed in mm
1042   const auto &Fd  = mm->Fd;
1043   const auto &Fo  = mm->Fo;
1044   const auto &Fdi = Fd.graph.row_map;
1045   const auto &Foi = Fo.graph.row_map;
1046   auto       &Fda = Fd.values;
1047   auto       &Foa = Fo.values;
1048   auto        Fm  = Fd.numRows();
1049 
       // Fetch the plans and the kernel launch parameters computed for splitting F
1050   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1051   PetscSF             &bcastSF      = mm->sf;
1052   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1053   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1054   PetscInt             teamSize     = mm->F_TeamSize;
1055   PetscInt             vectorLength = mm->F_VectorLength;
1056   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1057   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1058 
1059   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1060 
1061   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1062   PetscCallCXX(Kokkos::parallel_for(
1063     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1064       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1065         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1066         if (i < Fm) {
1067           PetscInt nzLeft = F_NzLeft(i);
1068           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1069           PetscInt blen   = Foi(i + 1) - Foi(i);
1070           PetscInt Fii    = Fdi(i) + Foi(i); // start of row i in leafBuf, i.e., the sum of the Fd and Fo row starts
1071 
1072           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1073             PetscScalar val = leafBuf(Fii + j);
1074             if (j < nzLeft) { // left
1075               Foa(Foi(i) + j) = val;
1076             } else if (j < nzLeft + alen) { // diag
1077               Fda(Fdi(i) + j - nzLeft) = val;
1078             } else { // right
1079               Foa(Foi(i) + j - alen) = val; // skip the alen diag entries to get the position within Fo's row
1080             }
1081           });
1082         }
1083       });
1084     }));
1085   PetscFunctionReturn(PETSC_SUCCESS);
1086 }
1087 
// MatProductSymbolic_MPIAIJKokkos_AtB - AtB flavor of the symbolic product (presumably C = A^T * B, judging from the name
// and the use of the transposes of A's blocks below).
//
// Steps, with intermediate results stashed in mm for the numeric phase:
//   C3 = Aot * Bd, C4 = Aot * Bo            (Aot: transpose of A's off-diag block)
//   (Fd, Fo) = reduce of rows of (C3, C4)   via MatMPIAIJKokkosReduceBegin/End with MAT_INITIAL_MATRIX
//   C1 = Adt * Bd, C2_mid = Adt * Bo        (Adt: transpose of A's diag block), computed between ReduceBegin and ReduceEnd
//   C2 = C2_mid with column indices remapped through map
//   Cd = C1 + Fd, Co = C2 + Fo
//
// The product argument is not referenced in this routine.
1088 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1089 {
1090   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1091   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1092   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1093   PetscInt        cstart, cend;
1094   MPI_Comm        comm;
1095 
1096   PetscFunctionBegin;
1097   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1098   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1099   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
       // NOTE(review): Ad and Ao are fetched but not used afterwards in this routine; confirm whether the calls are needed for side effects
1100   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1101   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1102   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1103   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1104 
1105   // TODO: add command line options to select spgemm algorithms
1106   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1107 
1108   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1109 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1110   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1111   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1112   #endif
1113 #endif
1114 
       // One kernel handle per spgemm product (C1, C2_mid, C3, C4)
1115   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1116   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1117   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1118   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1119 
1120   // Aot * (B's diag + B's off-diag)
1121   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1122   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1123   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1124   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1125   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1126   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1127 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1128 
1129   PetscCallCXX(sort_crs_matrix(mm->C3));
1130   PetscCallCXX(sort_crs_matrix(mm->C4));
1131 #endif
1132 
1133   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
       // map_h has one entry per column of B's off-diag block; it is passed to ReduceBegin as the map output array
1134   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1135   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1136   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1137 
1138   // Adt * (B's diag + B's off-diag)
1139   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1140   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1141   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1142   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1143 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1144   PetscCallCXX(sort_crs_matrix(mm->C1));
1145   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1146 #endif
1147 
1148   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1149 
1150   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1151   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1152   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1153   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1154   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1155 
1156   // C = (C1+Fd, C2+Fo)
       // The spgemm handles kh1/kh2 are reused to also carry the spadd handles
1157   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1158   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1159   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1160   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1161   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1162   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1163   PetscFunctionReturn(PETSC_SUCCESS);
1164 }
1165 
// MatProductNumeric_MPIAIJKokkos_AtB - numeric phase matching MatProductSymbolic_MPIAIJKokkos_AtB.
//
// Recomputes the values of C1, C2_mid, C3, C4 with the kernel handles stored in mm, redoes the reduction of (C3, C4) into
// (Fd, Fo) in MAT_REUSE_MATRIX mode, and re-adds Cd = C1 + Fd, Co = C2 + Fo. No symbolic work is repeated here.
// The product argument is not referenced in this routine.
1166 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1167 {
1168   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1169   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1170   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1171   MPI_Comm        comm;
1172 
1173   PetscFunctionBegin;
1174   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1175   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1176   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1177   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1178   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1179 
1180   // Aot * (B's diag + B's off-diag)
1181   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1182   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1183 
1184   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
       // In MAT_REUSE_MATRIX mode the plans come from mm, so placeholders (0/NULL) are passed for the other arguments
1185   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1186 
1187   // Adt * (B's diag + B's off-diag)
1188   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1189   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1190 
1191   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1192 
1193   // C = (C1+Fd, C2+Fo)
       // C2 is not rebuilt here; the symbolic phase constructed it from C2_mid's value and row-map views
1194   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1195   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1196   PetscFunctionReturn(PETSC_SUCCESS);
1197 }
1198 
1199 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1200 
1201   Input Parameters:
1202 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1203 .  A        - an MPIAIJKOKKOS matrix
1204 .  B        - an MPIAIJKOKKOS matrix
1205 -  mm       - a struct used to stash intermediate data when computing AB. Persists from symbolic to numeric operations.
1206 */
1207 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1208 {
1209   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1210   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1211   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1212 
1213   PetscFunctionBegin;
1214   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1215   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1216   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1217   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1218 
1219   // TODO: add command line options to select spgemm algorithms
1220   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1221 
1222   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1223 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1224   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1225   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1226   #endif
1227 #endif
1228 
1229   mm->kh1.create_spgemm_handle(spgemm_alg);
1230   mm->kh2.create_spgemm_handle(spgemm_alg);
1231   mm->kh3.create_spgemm_handle(spgemm_alg);
1232   mm->kh4.create_spgemm_handle(spgemm_alg);
1233 
1234   // Bcast B's rows to form F, and overlap the communication
1235   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1236   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1237 
1238   // A's diag * (B's diag + B's off-diag)
1239   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1240   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1241   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1242   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1243   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1244   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1245 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1246   PetscCallCXX(sort_crs_matrix(mm->C1));
1247   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1248 #endif
1249 
1250   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1251 
1252   // A's off-diag * (F's diag + F's off-diag)
1253   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1254   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1255   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1256   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1257 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1258   PetscCallCXX(sort_crs_matrix(mm->C3));
1259   PetscCallCXX(sort_crs_matrix(mm->C4));
1260 #endif
1261 
1262   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1263   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1264   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1265   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1266   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1267 
1268   // C = (Cd, Co) = (C1+C3, C2+C4)
1269   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1270   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1271   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1272   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1273   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1274   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1275   PetscFunctionReturn(PETSC_SUCCESS);
1276 }
1277 
1278 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1279 {
1280   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1281   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1282   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1283 
1284   PetscFunctionBegin;
1285   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1286   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1287   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1288   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1289 
1290   // Bcast B's rows to form F, and overlap the communication
1291   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1292 
1293   // A's diag * (B's diag + B's off-diag)
1294   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1295   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1296 
1297   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1298 
1299   // A's off-diag * (F's diag + F's off-diag)
1300   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1301   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1302 
1303   // C = (Cd, Co) = (C1+C3, C2+C4)
1304   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1305   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1306   PetscFunctionReturn(PETSC_SUCCESS);
1307 }
1308 
1309 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1310 {
1311   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1312   Mat_Product                 *product;
1313   MatProductData_MPIAIJKokkos *pdata;
1314   MatProductType               ptype;
1315   Mat                          A, B;
1316 
1317   PetscFunctionBegin;
1318   MatCheckProduct(C, 1); // make sure C is a product
1319   product = C->product;
1320   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1321   ptype   = product->type;
1322   A       = product->A;
1323   B       = product->B;
1324 
1325   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1326   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1327   // we still do numeric.
1328   if (pdata->reusesym) { // numeric reuses results from symbolic
1329     pdata->reusesym = PETSC_FALSE;
1330     PetscFunctionReturn(PETSC_SUCCESS);
1331   }
1332 
1333   if (ptype == MATPRODUCT_AB) {
1334     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1335   } else if (ptype == MATPRODUCT_AtB) {
1336     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1337   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1338     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1339     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1340   }
1341 
1342   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1343   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1344   PetscFunctionReturn(PETSC_SUCCESS);
1345 }
1346 
1347 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1348 {
1349   Mat                          A, B;
1350   Mat_Product                 *product;
1351   MatProductType               ptype;
1352   MatProductData_MPIAIJKokkos *pdata;
1353   MatMatStruct                *mm = NULL;
1354   PetscInt                     m, n, M, N;
1355   Mat                          Cd, Co;
1356   MPI_Comm                     comm;
1357 
1358   PetscFunctionBegin;
1359   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1360   MatCheckProduct(C, 1);
1361   product = C->product;
1362   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1363   ptype = product->type;
1364   A     = product->A;
1365   B     = product->B;
1366 
1367   switch (ptype) {
1368   case MATPRODUCT_AB:
1369     m = A->rmap->n;
1370     n = B->cmap->n;
1371     M = A->rmap->N;
1372     N = B->cmap->N;
1373     break;
1374   case MATPRODUCT_AtB:
1375     m = A->cmap->n;
1376     n = B->cmap->n;
1377     M = A->cmap->N;
1378     N = B->cmap->N;
1379     break;
1380   case MATPRODUCT_PtAP:
1381     m = B->cmap->n;
1382     n = B->cmap->n;
1383     M = B->cmap->N;
1384     N = B->cmap->N;
1385     break; /* BtAB */
1386   default:
1387     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1388   }
1389 
1390   PetscCall(MatSetSizes(C, m, n, M, N));
1391   PetscCall(PetscLayoutSetUp(C->rmap));
1392   PetscCall(PetscLayoutSetUp(C->cmap));
1393   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1394 
1395   pdata           = new MatProductData_MPIAIJKokkos();
1396   pdata->reusesym = product->api_user;
1397 
1398   if (ptype == MATPRODUCT_AB) {
1399     auto mmAB = new MatMatStruct_AB();
1400     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1401     mm = pdata->mmAB = mmAB;
1402   } else if (ptype == MATPRODUCT_AtB) {
1403     auto mmAtB = new MatMatStruct_AtB();
1404     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1405     mm = pdata->mmAtB = mmAtB;
1406   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1407     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1408 
1409     auto mmAB = new MatMatStruct_AB();
1410     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1411     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1412     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1413     pdata->mmAB = mmAB;
1414 
1415     m = A->rmap->n; // Z's layout
1416     n = B->cmap->n;
1417     M = A->rmap->N;
1418     N = B->cmap->N;
1419     PetscCall(MatCreate(comm, &Z));
1420     PetscCall(MatSetSizes(Z, m, n, M, N));
1421     PetscCall(PetscLayoutSetUp(Z->rmap));
1422     PetscCall(PetscLayoutSetUp(Z->cmap));
1423     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1424     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1425 
1426     auto mmAtB = new MatMatStruct_AtB();
1427     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1428 
1429     pdata->Z = Z; // give ownership to pdata
1430     mm = pdata->mmAtB = mmAtB;
1431   }
1432 
1433   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1434   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1435   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1436 
1437   C->product->data       = pdata;
1438   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1439   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1440   PetscFunctionReturn(PETSC_SUCCESS);
1441 }
1442 
1443 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1444 {
1445   Mat_Product *product = mat->product;
1446   PetscBool    match   = PETSC_FALSE;
1447   PetscBool    usecpu  = PETSC_FALSE;
1448 
1449   PetscFunctionBegin;
1450   MatCheckProduct(mat, 1);
1451   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1452   if (match) { /* we can always fallback to the CPU if requested */
1453     switch (product->type) {
1454     case MATPRODUCT_AB:
1455       if (product->api_user) {
1456         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1457         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1458         PetscOptionsEnd();
1459       } else {
1460         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1461         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1462         PetscOptionsEnd();
1463       }
1464       break;
1465     case MATPRODUCT_AtB:
1466       if (product->api_user) {
1467         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1468         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1469         PetscOptionsEnd();
1470       } else {
1471         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1472         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1473         PetscOptionsEnd();
1474       }
1475       break;
1476     case MATPRODUCT_PtAP:
1477       if (product->api_user) {
1478         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1479         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1480         PetscOptionsEnd();
1481       } else {
1482         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1483         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1484         PetscOptionsEnd();
1485       }
1486       break;
1487     default:
1488       break;
1489     }
1490     match = (PetscBool)!usecpu;
1491   }
1492   if (match) {
1493     switch (product->type) {
1494     case MATPRODUCT_AB:
1495     case MATPRODUCT_AtB:
1496     case MATPRODUCT_PtAP:
1497       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1498       break;
1499     default:
1500       break;
1501     }
1502   }
1503   /* fallback to MPIAIJ ops */
1504   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1505   PetscFunctionReturn(PETSC_SUCCESS);
1506 }
1507 
1508 // Mirror of MatCOOStruct_MPIAIJ on device
1509 struct MatCOOStruct_MPIAIJKokkos {
1510   PetscCount           n;
1511   PetscSF              sf;
1512   PetscCount           Annz, Bnnz;
1513   PetscCount           Annz2, Bnnz2;
1514   PetscCountKokkosView Ajmap1, Aperm1;
1515   PetscCountKokkosView Bjmap1, Bperm1;
1516   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1517   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1518   PetscCountKokkosView Cperm1;
1519   MatScalarKokkosView  sendbuf, recvbuf;
1520 
1521   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1522   {
1523     auto &exec = PetscGetKokkosExecutionSpace();
1524 
1525     n       = coo_h->n;
1526     sf      = coo_h->sf;
1527     Annz    = coo_h->Annz;
1528     Bnnz    = coo_h->Bnnz;
1529     Annz2   = coo_h->Annz2;
1530     Bnnz2   = coo_h->Bnnz2;
1531     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1532     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1533     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1534     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1535     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1536     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1537     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1538     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1539     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1540     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1541     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1542     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1543     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1544     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1545   }
1546 
1547   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1548 };
1549 
1550 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1551 {
1552   PetscFunctionBegin;
1553   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1554   PetscFunctionReturn(PETSC_SUCCESS);
1555 }
1556 
1557 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1558 {
1559   PetscContainer             container_h, container_d;
1560   MatCOOStruct_MPIAIJ       *coo_h;
1561   MatCOOStruct_MPIAIJKokkos *coo_d;
1562 
1563   PetscFunctionBegin;
1564   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1565   mat->preallocated = PETSC_TRUE;
1566   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1567   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1568   PetscCall(MatZeroEntries(mat));
1569 
1570   // Copy the COO struct to device
1571   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1572   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1573   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1574 
1575   // Put the COO struct in a container and then attach that to the matrix
1576   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1577   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1578   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1579   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1580   PetscCall(PetscContainerDestroy(&container_d));
1581   PetscFunctionReturn(PETSC_SUCCESS);
1582 }
1583 
1584 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1585 {
1586   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1587   Mat                            A = mpiaij->A, B = mpiaij->B;
1588   MatScalarKokkosView            Aa, Ba;
1589   MatScalarKokkosView            v1;
1590   PetscMemType                   memtype;
1591   PetscContainer                 container;
1592   MatCOOStruct_MPIAIJKokkos     *coo;
1593   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
1594 
1595   PetscFunctionBegin;
1596   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1597   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1598 
1599   const auto &n      = coo->n;
1600   const auto &Annz   = coo->Annz;
1601   const auto &Annz2  = coo->Annz2;
1602   const auto &Bnnz   = coo->Bnnz;
1603   const auto &Bnnz2  = coo->Bnnz2;
1604   const auto &vsend  = coo->sendbuf;
1605   const auto &v2     = coo->recvbuf;
1606   const auto &Ajmap1 = coo->Ajmap1;
1607   const auto &Ajmap2 = coo->Ajmap2;
1608   const auto &Aimap2 = coo->Aimap2;
1609   const auto &Bjmap1 = coo->Bjmap1;
1610   const auto &Bjmap2 = coo->Bjmap2;
1611   const auto &Bimap2 = coo->Bimap2;
1612   const auto &Aperm1 = coo->Aperm1;
1613   const auto &Aperm2 = coo->Aperm2;
1614   const auto &Bperm1 = coo->Bperm1;
1615   const auto &Bperm2 = coo->Bperm2;
1616   const auto &Cperm1 = coo->Cperm1;
1617 
1618   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1619   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1620     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1621   } else {
1622     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1623   }
1624 
1625   if (imode == INSERT_VALUES) {
1626     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1627     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1628   } else {
1629     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1630     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1631   }
1632 
1633   PetscCall(PetscLogGpuTimeBegin());
1634   /* Pack entries to be sent to remote */
1635   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1636 
1637   /* Send remote entries to their owner and overlap the communication with local computation */
1638   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1639   /* Add local entries to A and B in one kernel */
1640   Kokkos::parallel_for(
1641     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1642       PetscScalar sum = 0.0;
1643       if (i < Annz) {
1644         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1645         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1646       } else {
1647         i -= Annz;
1648         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1649         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1650       }
1651     });
1652   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1653 
1654   /* Add received remote entries to A and B in one kernel */
1655   Kokkos::parallel_for(
1656     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1657       if (i < Annz2) {
1658         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1659       } else {
1660         i -= Annz2;
1661         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1662       }
1663     });
1664   PetscCall(PetscLogGpuTimeEnd());
1665 
1666   if (imode == INSERT_VALUES) {
1667     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1668     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1669   } else {
1670     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1671     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1672   }
1673   PetscFunctionReturn(PETSC_SUCCESS);
1674 }
1675 
1676 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1677 {
1678   PetscFunctionBegin;
1679   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1680   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1681   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1682   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1683   PetscCall(MatDestroy_MPIAIJ(A));
1684   PetscFunctionReturn(PETSC_SUCCESS);
1685 }
1686 
1687 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1688 {
1689   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1690   PetscBool   congruent;
1691 
1692   PetscFunctionBegin;
1693   PetscCall(MatHasCongruentLayouts(A, &congruent));
1694   if (congruent) { // square matrix and the diagonals are solely in the diag block
1695     PetscCall(MatShift(mpiaij->A, a));
1696   } else { // too hard, use the general version
1697     PetscCall(MatShift_Basic(A, a));
1698   }
1699   PetscFunctionReturn(PETSC_SUCCESS);
1700 }
1701 
1702 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1703 {
1704   PetscFunctionBegin;
1705   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1706   B->ops->mult                  = MatMult_MPIAIJKokkos;
1707   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1708   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1709   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1710   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1711   B->ops->shift                 = MatShift_MPIAIJKokkos;
1712 
1713   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1714   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1715   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1716   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1717   PetscFunctionReturn(PETSC_SUCCESS);
1718 }
1719 
1720 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1721 {
1722   Mat         B;
1723   Mat_MPIAIJ *a;
1724 
1725   PetscFunctionBegin;
1726   if (reuse == MAT_INITIAL_MATRIX) {
1727     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1728   } else if (reuse == MAT_REUSE_MATRIX) {
1729     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1730   }
1731   B = *newmat;
1732 
1733   B->boundtocpu = PETSC_FALSE;
1734   PetscCall(PetscFree(B->defaultvectype));
1735   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1736   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1737 
1738   a = static_cast<Mat_MPIAIJ *>(A->data);
1739   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1740   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1741   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1742   PetscCall(MatSetOps_MPIAIJKokkos(B));
1743   PetscFunctionReturn(PETSC_SUCCESS);
1744 }
1745 
1746 /*MC
1747    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1748 
1749    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1750 
1751    Options Database Key:
1752 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1753 
1754   Level: beginner
1755 
1756 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1757 M*/
1758 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1759 {
1760   PetscFunctionBegin;
1761   PetscCall(PetscKokkosInitializeCheck());
1762   PetscCall(MatCreate_MPIAIJ(A));
1763   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1764   PetscFunctionReturn(PETSC_SUCCESS);
1765 }
1766 
1767 /*@C
1768   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
1769   (the default parallel PETSc format).  This matrix will ultimately pushed down
1770   to Kokkos for calculations.
1771 
1772   Collective
1773 
1774   Input Parameters:
1775 + comm  - MPI communicator, set to `PETSC_COMM_SELF`
1776 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1777            This value should be the same as the local size used in creating the
1778            y vector for the matrix-vector product y = Ax.
1779 . n     - This value should be the same as the local size used in creating the
1780        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1781        calculated if N is given) For square matrices n is almost always `m`.
1782 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1783 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1784 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1785            (same value is used for all local rows)
1786 . d_nnz - array containing the number of nonzeros in the various rows of the
1787            DIAGONAL portion of the local submatrix (possibly different for each row)
1788            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1789            The size of this array is equal to the number of local rows, i.e `m`.
1790            For matrices you plan to factor you must leave room for the diagonal entry and
1791            put in the entry even if it is zero.
1792 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1793            submatrix (same value is used for all local rows).
1794 - o_nnz - array containing the number of nonzeros in the various rows of the
1795            OFF-DIAGONAL portion of the local submatrix (possibly different for
1796            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1797            structure. The size of this array is equal to the number
1798            of local rows, i.e `m`.
1799 
1800   Output Parameter:
1801 . A - the matrix
1802 
1803   Level: intermediate
1804 
1805   Notes:
1806   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1807   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1808   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1809 
1810   The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
1811   storage.  That is, the stored row and column indices can begin at
1812   either one (as in Fortran) or zero.
1813 
1814 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1815           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
1816 @*/
1817 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1818 {
1819   PetscMPIInt size;
1820 
1821   PetscFunctionBegin;
1822   PetscCall(MatCreate(comm, A));
1823   PetscCall(MatSetSizes(*A, m, n, M, N));
1824   PetscCallMPI(MPI_Comm_size(comm, &size));
1825   if (size > 1) {
1826     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1827     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1828   } else {
1829     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1830     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1831   }
1832   PetscFunctionReturn(PETSC_SUCCESS);
1833 }
1834