xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision bcee047adeeb73090d7e36cc71e39fc287cdbb97)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <KokkosSparse_spadd.hpp>
7 #include <KokkosSparse_spgemm.hpp>
8 
9 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
10 {
11   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
12 
13   PetscFunctionBegin;
14   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
15   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
16      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
17    */
18   if (mode == MAT_FINAL_ASSEMBLY) {
19     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
20     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
21     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
22   }
23 
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
27 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
28 {
29   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
30 
31   PetscFunctionBegin;
32   PetscCall(PetscLayoutSetUp(mat->rmap));
33   PetscCall(PetscLayoutSetUp(mat->cmap));
34 #if defined(PETSC_USE_DEBUG)
35   if (d_nnz) {
36     PetscInt i;
37     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
38   }
39   if (o_nnz) {
40     PetscInt i;
41     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
42   }
43 #endif
44 #if defined(PETSC_USE_CTABLE)
45   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
46 #else
47   PetscCall(PetscFree(mpiaij->colmap));
48 #endif
49   PetscCall(PetscFree(mpiaij->garray));
50   PetscCall(VecDestroy(&mpiaij->lvec));
51   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
52   /* Because the B will have been resized we simply destroy it and create a new one each time */
53   PetscCall(MatDestroy(&mpiaij->B));
54 
55   if (!mpiaij->A) {
56     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
57     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
58   }
59   if (!mpiaij->B) {
60     PetscMPIInt size;
61     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
62     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
63     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
64   }
65   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
66   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
67   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
68   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
69   mat->preallocated = PETSC_TRUE;
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
74 {
75   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
76   PetscInt    nt;
77 
78   PetscFunctionBegin;
79   PetscCall(VecGetLocalSize(xx, &nt));
80   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
81   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
82   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
83   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
84   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
85   PetscFunctionReturn(PETSC_SUCCESS);
86 }
87 
88 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
89 {
90   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
91   PetscInt    nt;
92 
93   PetscFunctionBegin;
94   PetscCall(VecGetLocalSize(xx, &nt));
95   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
96   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
97   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
98   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
99   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
100   PetscFunctionReturn(PETSC_SUCCESS);
101 }
102 
103 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
104 {
105   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
106   PetscInt    nt;
107 
108   PetscFunctionBegin;
109   PetscCall(VecGetLocalSize(xx, &nt));
110   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
111   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
112   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
113   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
114   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
115   PetscFunctionReturn(PETSC_SUCCESS);
116 }
117 
118 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
119    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
120    C still uses local column ids. Their corresponding global column ids are returned in glob.
121 */
122 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
123 {
124   Mat             Ad, Ao;
125   const PetscInt *cmap;
126 
127   PetscFunctionBegin;
128   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
129   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
130   if (glob) {
131     PetscInt cst, i, dn, on, *gidx;
132     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
133     PetscCall(MatGetLocalSize(Ao, NULL, &on));
134     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
135     PetscCall(PetscMalloc1(dn + on, &gidx));
136     for (i = 0; i < dn; i++) gidx[i] = cst + i;
137     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
138     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
139   }
140   PetscFunctionReturn(PETSC_SUCCESS);
141 }
142 
143 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
144 struct MatMatStruct {
145   PetscInt            n, *garray;     // C's garray and its size.
146   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
147   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
148   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
149   PetscIntKokkosView  E_NzLeft;
150   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
151   MatScalarKokkosView rootBuf, leafBuf;
152   KokkosCsrMatrix     Fd, Fo; // F in split form
153 
154   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
155   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
156   KernelHandle kh3; // compute C3
157   KernelHandle kh4; // compute C4
158 
159   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
160   PetscInt E_VectorLength;
161   PetscInt E_RowsPerTeam;
162   PetscInt F_TeamSize;
163   PetscInt F_VectorLength;
164   PetscInt F_RowsPerTeam;
165 
166   ~MatMatStruct()
167   {
168     PetscFunctionBegin;
169     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
170     PetscFunctionReturnVoid();
171   }
172 };
173 
174 struct MatMatStruct_AB : public MatMatStruct {
175   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
176   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
177   PetscIntKokkosView rowoffset;
178 };
179 
180 struct MatMatStruct_AtB : public MatMatStruct {
181   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
182   MatColIdxKokkosView Fdjperm;
183   MatColIdxKokkosView Fojmap;
184   MatColIdxKokkosView Fojperm;
185 };
186 
187 struct MatProductData_MPIAIJKokkos {
188   MatMatStruct_AB  *mmAB     = nullptr;
189   MatMatStruct_AtB *mmAtB    = nullptr;
190   PetscBool         reusesym = PETSC_FALSE;
191   Mat               Z        = nullptr; // store Z=AB in computing BtAB
192 
193   ~MatProductData_MPIAIJKokkos()
194   {
195     delete mmAB;
196     delete mmAtB;
197     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
198   }
199 };
200 
201 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202 {
203   PetscFunctionBegin;
204   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
205   PetscFunctionReturn(PETSC_SUCCESS);
206 }
207 
208 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
209    It is similar to MatCreateMPIAIJWithSplitArrays.
210 
211   Input Parameters:
212 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
213 .  A     - the diag matrix using local col ids
214 -  B     - the offdiag matrix using global col ids
215 
216   Output Parameter:
217 .  mat   - the updated MATMPIAIJKOKKOS matrix
218 */
219 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
220 {
221   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
222   PetscInt    m, n, M, N, Am, An, Bm, Bn;
223 
224   PetscFunctionBegin;
225   PetscCall(MatGetSize(mat, &M, &N));
226   PetscCall(MatGetLocalSize(mat, &m, &n));
227   PetscCall(MatGetLocalSize(A, &Am, &An));
228   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
229 
230   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
231   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
232   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
233   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
234   mpiaij->A      = A;
235   mpiaij->B      = B;
236   mpiaij->garray = garray;
237 
238   mat->preallocated     = PETSC_TRUE;
239   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
240 
241   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
242   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
243   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
244     also gets mpiaij->B compacted, with its col ids and size reduced
245   */
246   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
247   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
248   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251 
252 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
253 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
254 template <class ExecutionSpace>
255 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
256 {
257   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
258 
259   PetscFunctionBegin;
260   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
261 
262   if (nnz_per_row < 1) nnz_per_row = 1;
263 
264   int max_vector_length = teamPolicy.vector_length_max();
265 
266   if (vector_length < 1) {
267     vector_length = 1;
268     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
269   }
270 
271   // Determine rows per thread
272   if (rows_per_thread < 1) {
273     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
274     else {
275       if (nnz_per_row < 20 && nnz > 5000000) {
276         rows_per_thread = 256;
277       } else rows_per_thread = 64;
278     }
279   }
280 
281   if (team_size < 1) {
282     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
283       team_size = 256 / vector_length;
284     } else {
285       team_size = 1;
286     }
287   }
288 
289   rows_per_team = rows_per_thread * team_size;
290 
291   if (rows_per_team < 0) {
292     PetscInt nnz_per_team = 4096;
293     PetscInt conc         = ExecutionSpace().concurrency();
294     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
295     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
296   }
297   PetscFunctionReturn(PETSC_SUCCESS);
298 }
299 
300 /*
301   Reduce two sets of global indices into local ones
302 
303   Input Parameters:
304 +  n1          - size of garray1[], the first set
305 .  garray1[n1] - a sorted global index array (without duplicates)
306 .  m           - size of indices[], the second set
307 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
308 
309   Output Parameters:
310 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
311 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
312 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
313 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
314 
315    Example, say
316     n1         = 5
317     garray1[5] = {1, 4, 7, 8, 10}
318     m          = 4
319     indices[4] = {2, 4, 8, 9}
320 
321    Combining them together, we have 7 global indices in garray2[]
322     n2         = 7
323     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
324 
325    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
326     map[5] = {0, 2, 3, 4, 6}
327 
328    On output, indices[] is updated with local indices
329     indices[4] = {1, 2, 4, 5}
330 */
331 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
332 {
333   PetscHMapI    g2l = nullptr;
334   PetscHashIter iter;
335   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
336   PetscInt      n2, *garray2;
337 
338   PetscFunctionBegin;
339   tot = 0;
340   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
341   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
342     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
343     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
344   }
345 
346   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
347     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
348     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
349   }
350 
351   // Pull out (unique) globals in the hash table and put them in garray2[]
352   n2 = tot;
353   PetscCall(PetscMalloc1(n2, &garray2));
354   tot = 0;
355   PetscHashIterBegin(g2l, iter);
356   while (!PetscHashIterAtEnd(g2l, iter)) {
357     PetscHashIterGetKey(g2l, iter, key);
358     PetscHashIterNext(g2l, iter);
359     garray2[tot++] = key;
360   }
361 
362   // Sort garray2[] and then map them to local indices starting from 0
363   PetscCall(PetscSortInt(n2, garray2));
364   PetscCall(PetscHMapIClear(g2l));
365   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
366 
367   // Rewrite indices[] with local indices
368   for (PetscInt i = 0; i < m; i++) {
369     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
370     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
371     indices[i] = val;
372   }
373   // Record the map that maps garray1[i] to garray2[map[i]]
374   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
375   PetscCall(PetscHMapIDestroy(&g2l));
376   *n2_      = n2;
377   *garray2_ = garray2;
378   PetscFunctionReturn(PETSC_SUCCESS);
379 }
380 
381 /*
382   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
383 
384   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
385 
  Think of each row of E as a leaf; the given ownerSF then specifies the roots of these leaves. A root may be connected to multiple leaves.
  This routine sparse-merges the leaves (rows) at their roots, forming potentially longer rows of F. The number of rows of F equals the number of roots (nroots) of ownerSF.
388 
389   Input Parameters:
390 +  comm       - MPI communicator of E
391 .  A          - diag block of E, using local column indices
392 .  B          - off-diag block of E, using local column indices
393 .  cstart      - (global) start column of Ed
394 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
395 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
396 .  ownerSF     - the SF specifies ownership (root) of rows in E
397 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
398 -  mm          - to stash intermediate data structures for reuse
399 
400   Output Parameters:
401 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
402 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
403 
404   Notes:
405   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
406 
407  */
408 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
409 {
410   PetscFunctionBegin;
411   if (reuse == MAT_INITIAL_MATRIX) {
412     PetscInt Em = A.numRows(), Fm;
413     PetscInt n1 = B.numCols();
414 
415     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
416 
417     // Do the analysis on host
418     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
419     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
420     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
421     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
422     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
423     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
424 
425     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
426     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
427     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
428     for (PetscInt i = 0; i < Em; i++) {
429       const PetscInt *first, *last, *it;
430       PetscInt        count, step;
431       // std::lower_bound(first,last,cstart), but need to use global column indices
432       first = Bj + Bi[i];
433       last  = Bj + Bi[i + 1];
434       count = last - first;
435       while (count > 0) {
436         it   = first;
437         step = count / 2;
438         it += step;
439         if (garray1[*it] < cstart) { // map local to global
440           first = ++it;
441           count -= step + 1;
442         } else count = step;
443       }
444       E_NzLeft[i] = first - (Bj + Bi[i]);
445       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
446     }
447 
448     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
449     const PetscMPIInt *iranks, *ranks;
450     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
451     PetscInt           niranks, nranks;
452     MPI_Request       *reqs;
453     PetscMPIInt        tag;
454     PetscSF            reduceSF;
455     PetscInt          *sdisp, *rdisp;
456 
457     PetscCall(PetscCommGetNewTag(comm, &tag));
458     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
459     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
460 
461     // Find out length of each row I will receive. Even for the same row index, when they are from
462     // different senders, they might have different lengths (and sparsity patterns)
463     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
464     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
465 
466     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
467 
468     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
469     recvRowLen[0] = 0; // since we will make it in CSR format later
470     recvRowLen++;      // advance the pointer now
471     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
472     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
473     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
474 
475     // Build the real PetscSF for reducing E rows (buffer to buffer)
476     rdisp[0] = 0;
477     for (PetscInt i = 0; i < niranks; i++) {
478       rdisp[i + 1] = rdisp[i];
479       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
480     }
481     recvRowLen--; // put it back into csr format
482     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
483 
484     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
485     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
486     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
487 
488     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
489     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
490     PetscSFNode *iremote;
491 
492     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
493     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
494     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
495 
496     for (PetscInt i = 0; i < nranks; i++) {
497       PetscInt count = 0;
498       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
499       for (PetscInt j = 0; j < count; j++) {
500         iremote[nleaves + j].rank  = ranks[i];
501         iremote[nleaves + j].index = sdisp[i] + j;
502       }
503       nleaves += count;
504     }
505     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
506 
507     PetscCall(PetscSFCreate(comm, &reduceSF));
508     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
509 
510     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
511     PetscInt *sendCol, *recvCol;
512     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
513     for (PetscInt k = 0; k < roffset[nranks]; k++) {
514       PetscInt  i      = rmine[k]; // row to be copied
515       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
516       PetscInt  nzLeft = E_NzLeft[i];
517       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
518       for (PetscInt j = 0; j < alen + blen; j++) {
519         if (j < nzLeft) {
520           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
521         } else if (j < nzLeft + alen) {
522           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
523         } else {
524           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
525         }
526       }
527     }
528     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
529     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
530 
531     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
532     PetscInt *recvRowPerm, *recvColSorted;
533     PetscInt *recvNzPerm, *recvNzPermSorted;
534     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
535 
536     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
537     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
538     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
539 
540     // i[] array, nz are always easiest to compute
541     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
542     MatRowMapType          *Fdi, *Foi;
543     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
544     PetscInt                iter;
545 
546     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
547     Kokkos::deep_copy(Foi_h, 0);
548     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
549     Foi  = Foi_h.data() + 1;
550     iter = 0;
551     while (iter < recvRowCnt) { // iter over received rows
552       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
553       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
554 
555       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
556 
557       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
558       PetscInt  nz    = 0; // nz (with dups) in the current row
559       PetscInt *jbuf  = recvColSorted + FnzDups;
560       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
561       PetscInt *jbuf2 = jbuf; // temp pointers
562       PetscInt *pbuf2 = pbuf;
563       for (PetscInt d = 0; d < dupRows; d++) {
564         PetscInt i   = recvRowPerm[iter + d];
565         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
566         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
567         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
568         jbuf2 += len;
569         pbuf2 += len;
570         nz += len;
571       }
572       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
573 
574       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
575       PetscInt cur = 0;
576       while (cur < nz) {
577         PetscInt curColIdx = jbuf[cur];
578         PetscInt dups      = 1;
579 
580         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
581         if (curColIdx >= cstart && curColIdx < cend) {
582           Fdi[curRowIdx]++;
583           FdnzDups += dups;
584         } else {
585           Foi[curRowIdx]++;
586           FonzDups += dups;
587         }
588         cur += dups;
589       }
590 
591       FnzDups += nz;
592       iter += dupRows; // Move to next unique row
593     }
594 
595     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
596     Foi = Foi_h.data();
597     for (PetscInt i = 0; i < Fm; i++) {
598       Fdi[i + 1] += Fdi[i];
599       Foi[i + 1] += Foi[i];
600     }
601     Fdnz = Fdi[Fm];
602     Fonz = Foi[Fm];
603     PetscCall(PetscFree2(sendCol, recvCol));
604 
605     // Allocate j, jmap, jperm for Fd and Fo
606     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
607     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
608     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
609     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
610     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
611     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
612 
613     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
614     Fdjmap[0] = 0;
615     Fojmap[0] = 0;
616     FnzDups   = 0;
617     Fdnz      = 0;
618     Fonz      = 0;
619     iter      = 0; // iter over received rows
620     while (iter < recvRowCnt) {
621       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
622       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
623       PetscInt nz        = 0;                           // nz (with dups) in the current row
624 
625       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
626       for (PetscInt d = 0; d < dupRows; d++) {
627         PetscInt i = recvRowPerm[iter + d];
628         nz += recvRowLen[i + 1] - recvRowLen[i];
629       }
630 
631       PetscInt *jbuf = recvColSorted + FnzDups;
632       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
633       PetscInt cur = 0;
634       while (cur < nz) {
635         PetscInt curColIdx = jbuf[cur];
636         PetscInt dups      = 1;
637 
638         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
639         if (curColIdx >= cstart && curColIdx < cend) {
640           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
641           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
642           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
643           FdnzDups += dups;
644           Fdnz++;
645         } else {
646           Foj[Fonz]        = curColIdx; // in global
647           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
648           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
649           FonzDups += dups;
650           Fonz++;
651         }
652         cur += dups;
653         FnzDups += dups;
654       }
655       iter += dupRows; // Move to next unique row
656     }
657     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
658     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
659 
660     // Combine global column indices in garray1[] and Foj[]
661     PetscInt n2, *garray2;
662 
663     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
664     mm->sf       = reduceSF;
665     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
666     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
667     mm->garray   = garray2; // give ownership, so no free
668     mm->n        = n2;
669     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
670     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
671     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
672     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
673     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
674 
675     // Output Fd and Fo in KokkosCsrMatrix format
676     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
677     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
678     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
679     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
680     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
681     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
682 
683     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
684     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
685 
686     // Compute kernel launch parameters in merging E
687     PetscInt teamSize, vectorLength, rowsPerTeam;
688 
689     teamSize = vectorLength = rowsPerTeam = -1;
690     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
691     mm->E_TeamSize     = teamSize;
692     mm->E_VectorLength = vectorLength;
693     mm->E_RowsPerTeam  = rowsPerTeam;
694   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
695 
696   // Handy aliases
697   auto       &Aa           = A.values;
698   auto       &Ba           = B.values;
699   const auto &Ai           = A.graph.row_map;
700   const auto &Bi           = B.graph.row_map;
701   const auto &E_NzLeft     = mm->E_NzLeft;
702   auto       &leafBuf      = mm->leafBuf;
703   auto       &rootBuf      = mm->rootBuf;
704   PetscSF     reduceSF     = mm->sf;
705   PetscInt    Em           = A.numRows();
706   PetscInt    teamSize     = mm->E_TeamSize;
707   PetscInt    vectorLength = mm->E_VectorLength;
708   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
709   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
710 
711   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
712   PetscCallCXX(Kokkos::parallel_for(
713     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
714       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
715         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
716         if (i < Em) {
717           PetscInt disp   = Ai(i) + Bi(i);
718           PetscInt alen   = Ai(i + 1) - Ai(i);
719           PetscInt blen   = Bi(i + 1) - Bi(i);
720           PetscInt nzleft = E_NzLeft(i);
721 
722           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
723             MatScalar &val = leafBuf(disp + j);
724             if (j < nzleft) { // B left
725               val = Ba(Bi(i) + j);
726             } else if (j < nzleft + alen) { // diag A
727               val = Aa(Ai(i) + j - nzleft);
728             } else { // B right
729               val = Ba(Bi(i) + j - alen);
730             }
731           });
732         }
733       });
734     }));
735   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
736   PetscFunctionReturn(PETSC_SUCCESS);
737 }
738 
739 // To finish MatMPIAIJKokkosReduce.
740 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
741 {
742   PetscFunctionBegin;
743   auto       &leafBuf  = mm->leafBuf;
744   auto       &rootBuf  = mm->rootBuf;
745   auto       &Fda      = mm->Fd.values;
746   const auto &Fdjmap   = mm->Fdjmap;
747   const auto &Fdjperm  = mm->Fdjperm;
748   auto        Fdnz     = mm->Fd.nnz();
749   auto       &Foa      = mm->Fo.values;
750   const auto &Fojmap   = mm->Fojmap;
751   const auto &Fojperm  = mm->Fojperm;
752   auto        Fonz     = mm->Fo.nnz();
753   PetscSF     reduceSF = mm->sf;
754 
755   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
756 
757   // Reduce data in rootBuf to Fd and Fo
758   PetscCallCXX(Kokkos::parallel_for(
759     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
760       PetscScalar sum = 0.0;
761       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
762       Fda(i) = sum;
763     }));
764 
765   PetscCallCXX(Kokkos::parallel_for(
766     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
767       PetscScalar sum = 0.0;
768       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
769       Foa(i) = sum;
770     }));
771   PetscFunctionReturn(PETSC_SUCCESS);
772 }
773 
774 /*
775   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
776 
777   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
778   device and involves various index mapping.
779 
780   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
781   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
782   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
783   F has the same column layout as E.
784 
  Conceptually F has global column indices. In this routine, we split F into the diagonal part Fd and the off-diagonal part Fo.
786   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
787   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
788   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
789   column indices in Fo and update Fo with local indices.
790 
791    Input Parameters:
792 +   E       - the MPIAIJKOKKOS matrix
793 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
794 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
795 -   mm      - to stash matproduct intermediate data structures
796 
797     Output Parameters:
798 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
799 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
800 
801     Notes:
802     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
    The routine is provided in split-phase form, MatMPIAIJKokkosBcastBegin/End(), so that computation can overlap with communication.
804 */
805 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
806 {
807   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
808   Mat               A = empi->A, B = empi->B; // diag and off-diag
809   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
810   PetscInt          Em = E->rmap->n; // #local rows
811   MPI_Comm          comm;
812 
813   PetscFunctionBegin;
814   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
815   if (reuse == MAT_INITIAL_MATRIX) {
816     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
817     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
818     const PetscInt *garray1 = empi->garray; // its size is n1
819     PetscInt        cstart, cend;
820     PetscSF         bcastSF;
821 
822     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
823 
824     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
825     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
826     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
827     for (PetscInt i = 0; i < Em; i++) {
828       const PetscInt *first, *last, *it;
829       PetscInt        count, step;
830       // std::lower_bound(first,last,cstart), but need to use global column indices
831       first = Bj + Bi[i];
832       last  = Bj + Bi[i + 1];
833       count = last - first;
834       while (count > 0) {
835         it   = first;
836         step = count / 2;
837         it += step;
838         if (empi->garray[*it] < cstart) { // map local to global
839           first = ++it;
840           count -= step + 1;
841         } else count = step;
842       }
843       E_NzLeft[i] = first - (Bj + Bi[i]);
844       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
845     }
846 
847     // Compute row pointer Fi of F
848     PetscInt *Fi, Fm, Fnz;
849     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
850     PetscCall(PetscMalloc1(Fm + 1, &Fi));
851     Fi[0] = 0;
852     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
853     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
854     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
855     Fnz = Fi[Fm];
856 
857     // Build the real PetscSF for bcasting E rows (buffer to buffer)
858     const PetscMPIInt *iranks, *ranks;
859     const PetscInt    *ioffset, *irootloc, *roffset;
860     PetscInt           niranks, nranks, *sdisp, *rdisp;
861     MPI_Request       *reqs;
862     PetscMPIInt        tag;
863 
864     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
865     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
866     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
867 
868     sdisp[0] = 0; // send displacement
869     for (PetscInt i = 0; i < niranks; i++) {
870       sdisp[i + 1] = sdisp[i];
871       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
872         PetscInt r = irootloc[j]; // row to be sent
873         sdisp[i + 1] += E_RowLen[r];
874       }
875     }
876 
877     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
878     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
879     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
880     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
881 
882     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
883     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
884     PetscSFNode *iremote;                  // give ownership to bcastSF
885     PetscCall(PetscMalloc1(nleaves, &iremote));
886     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
887       PetscInt k = 0;
888       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
889         iremote[j].rank  = ranks[i];
890         iremote[j].index = rdisp[i] + k; // their root location
891         k++;
892       }
893     }
894     PetscCall(PetscSFCreate(comm, &bcastSF));
895     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
896     PetscCall(PetscFree3(sdisp, rdisp, reqs));
897 
898     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
899     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
900     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
901     rowoffset[0]                     = 0;
902     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
903 
904     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
905     PetscInt *jbuf, *Fj;
906     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
907     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
908       PetscInt  i      = irootloc[k]; // row to be copied
909       PetscInt *buf    = &jbuf[rowoffset[k]];
910       PetscInt  nzLeft = E_NzLeft[i];
911       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
912       for (PetscInt j = 0; j < alen + blen; j++) {
913         if (j < nzLeft) {
914           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
915         } else if (j < nzLeft + alen) {
916           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
917         } else {
918           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
919         }
920       }
921     }
922     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
923     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
924 
925     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
926     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
927     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
928     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
929     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
930 
931     Fdi[0] = Foi[0] = 0;
932     for (PetscInt i = 0; i < Fm; i++) {
933       PetscInt *first, *last, *lb1, *lb2;
934       // cut the row into: Left, [cstart, cend), Right
935       first       = Fj + Fi[i];
936       last        = Fj + Fi[i + 1];
937       lb1         = std::lower_bound(first, last, cstart);
938       F_NzLeft[i] = lb1 - first;
939       lb2         = std::lower_bound(first, last, cend);
940       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
941       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
942     }
943     for (PetscInt i = 0; i < Fm; i++) {
944       Fdi[i + 1] += Fdi[i];
945       Foi[i + 1] += Foi[i];
946     }
947 
948     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
949     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
950     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
951     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
952 
953     for (PetscInt i = 0; i < Fm; i++) {
954       PetscInt nzLeft = F_NzLeft[i];
955       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
956       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
957         gid = Fj[Fi[i] + j];
958         if (j < nzLeft) { // left, in global
959           Foj[Foi[i] + j] = gid;
960         } else if (j < nzLeft + len) { // diag, in local
961           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
962         } else { // right, in global
963           Foj[Foi[i] + j - len] = gid;
964         }
965       }
966     }
967     PetscCall(PetscFree2(jbuf, Fj));
968     PetscCall(PetscFree(Fi));
969 
970     // Reduce global indices in Foj[] and garray1[] into local ones
971     PetscInt n2, *garray2;
972     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
973 
974     // Record the plans built above, for reuse
975     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
976     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
977     Kokkos::deep_copy(irootloc_h, tmp);
978     mm->sf        = bcastSF;
979     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
980     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
981     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
982     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
983     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
984     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
985     mm->garray    = garray2;
986     mm->n         = n2;
987 
988     // Output Fd and Fo in KokkosCsrMatrix format
989     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
990     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
991     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
992     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
993     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
994 
995     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
996     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
997 
998     // Compute kernel launch parameters in merging E or splitting F
999     PetscInt teamSize, vectorLength, rowsPerTeam;
1000 
1001     teamSize = vectorLength = rowsPerTeam = -1;
1002     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1003     mm->E_TeamSize     = teamSize;
1004     mm->E_VectorLength = vectorLength;
1005     mm->E_RowsPerTeam  = rowsPerTeam;
1006 
1007     teamSize = vectorLength = rowsPerTeam = -1;
1008     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1009     mm->F_TeamSize     = teamSize;
1010     mm->F_VectorLength = vectorLength;
1011     mm->F_RowsPerTeam  = rowsPerTeam;
1012   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1013 
1014   // Sync E's value to device
1015   akok->a_dual.sync_device();
1016   bkok->a_dual.sync_device();
1017 
1018   // Handy aliases
1019   const auto &Aa = akok->a_dual.view_device();
1020   const auto &Ba = bkok->a_dual.view_device();
1021   const auto &Ai = akok->i_dual.view_device();
1022   const auto &Bi = bkok->i_dual.view_device();
1023 
1024   // Fetch the plans
1025   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1026   PetscSF             &bcastSF   = mm->sf;
1027   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1028   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1029   PetscIntKokkosView  &irootloc  = mm->irootloc;
1030   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1031 
1032   PetscInt teamSize     = mm->E_TeamSize;
1033   PetscInt vectorLength = mm->E_VectorLength;
1034   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1035   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1036 
1037   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1038   PetscCallCXX(Kokkos::parallel_for(
1039     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1040       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1041         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1042         if (r < irootloc.extent(0)) {
1043           PetscInt i      = irootloc(r); // row i of E
1044           PetscInt disp   = rowoffset(r);
1045           PetscInt alen   = Ai(i + 1) - Ai(i);
1046           PetscInt blen   = Bi(i + 1) - Bi(i);
1047           PetscInt nzleft = E_NzLeft(i);
1048 
1049           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1050             if (j < nzleft) { // B left
1051               rootBuf(disp + j) = Ba(Bi(i) + j);
1052             } else if (j < nzleft + alen) { // diag A
1053               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1054             } else { // B right
1055               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1056             }
1057           });
1058         }
1059       });
1060     }));
1061   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1062   PetscFunctionReturn(PETSC_SUCCESS);
1063 }
1064 
1065 // To finish MatMPIAIJKokkosBcast.
1066 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1067 {
1068   PetscFunctionBegin;
1069   const auto &Fd  = mm->Fd;
1070   const auto &Fo  = mm->Fo;
1071   const auto &Fdi = Fd.graph.row_map;
1072   const auto &Foi = Fo.graph.row_map;
1073   auto       &Fda = Fd.values;
1074   auto       &Foa = Fo.values;
1075   auto        Fm  = Fd.numRows();
1076 
1077   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1078   PetscSF             &bcastSF      = mm->sf;
1079   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1080   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1081   PetscInt             teamSize     = mm->F_TeamSize;
1082   PetscInt             vectorLength = mm->F_VectorLength;
1083   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1084   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1085 
1086   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1087 
1088   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1089   PetscCallCXX(Kokkos::parallel_for(
1090     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1091       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1092         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1093         if (i < Fm) {
1094           PetscInt nzLeft = F_NzLeft(i);
1095           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1096           PetscInt blen   = Foi(i + 1) - Foi(i);
1097           PetscInt Fii    = Fdi(i) + Foi(i);
1098 
1099           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1100             PetscScalar val = leafBuf(Fii + j);
1101             if (j < nzLeft) { // left
1102               Foa(Foi(i) + j) = val;
1103             } else if (j < nzLeft + alen) { // diag
1104               Fda(Fdi(i) + j - nzLeft) = val;
1105             } else { // right
1106               Foa(Foi(i) + j - alen) = val;
1107             }
1108           });
1109         }
1110       });
1111     }));
1112   PetscFunctionReturn(PETSC_SUCCESS);
1113 }
1114 
1115 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1116 {
1117   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1118   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1119   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1120   PetscInt        cstart, cend;
1121   MPI_Comm        comm;
1122 
1123   PetscFunctionBegin;
1124   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1125   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1126   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1127   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1128   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1129   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1130   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1131 
1132   // TODO: add command line options to select spgemm algorithms
1133   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1134 
1135   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1136 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1137   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1138   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1139   #endif
1140 #endif
1141 
1142   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1143   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1144   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1145   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1146 
1147   // Aot * (B's diag + B's off-diag)
1148   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1149   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1150   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1151   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1152   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1153   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1154 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1155   PetscCallCXX(sort_crs_matrix(mm->C3));
1156   PetscCallCXX(sort_crs_matrix(mm->C4));
1157 #endif
1158 
1159   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1160   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1161   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1162   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1163 
1164   // Adt * (B's diag + B's off-diag)
1165   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1166   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1167   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1168   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1169 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1170   PetscCallCXX(sort_crs_matrix(mm->C1));
1171   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1172 #endif
1173 
1174   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1175 
1176   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1177   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1178   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1179   PetscCallCXX(Kokkos::parallel_for(
1180     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1181   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1182 
1183   // C = (C1+Fd, C2+Fo)
1184   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1185   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1186   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1187   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1188   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1189   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1190   PetscFunctionReturn(PETSC_SUCCESS);
1191 }
1192 
1193 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1194 {
1195   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1196   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1197   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1198   MPI_Comm        comm;
1199 
1200   PetscFunctionBegin;
1201   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1202   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1203   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1204   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1205   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1206 
1207   // Aot * (B's diag + B's off-diag)
1208   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1209   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1210 
1211   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1212   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1213 
1214   // Adt * (B's diag + B's off-diag)
1215   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1216   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1217 
1218   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1219 
1220   // C = (C1+Fd, C2+Fo)
1221   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1222   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1223   PetscFunctionReturn(PETSC_SUCCESS);
1224 }
1225 
1226 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1227 
1228   Input Parameters:
1229 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1230 .  A        - an MPIAIJKOKKOS matrix
1231 .  B        - an MPIAIJKOKKOS matrix
-  mm       - a struct used to stash intermediate data when computing AB. It persists from the symbolic to the numeric phase.
1233 */
1234 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1235 {
1236   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1237   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1238   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1239 
1240   PetscFunctionBegin;
1241   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1242   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1243   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1244   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1245 
1246   // TODO: add command line options to select spgemm algorithms
1247   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1248 
1249   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1250 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1251   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1252   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1253   #endif
1254 #endif
1255 
1256   mm->kh1.create_spgemm_handle(spgemm_alg);
1257   mm->kh2.create_spgemm_handle(spgemm_alg);
1258   mm->kh3.create_spgemm_handle(spgemm_alg);
1259   mm->kh4.create_spgemm_handle(spgemm_alg);
1260 
1261   // Bcast B's rows to form F, and overlap the communication
1262   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1263   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1264 
1265   // A's diag * (B's diag + B's off-diag)
1266   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1267   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1268   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1269   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1270   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1271   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1272 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1273   PetscCallCXX(sort_crs_matrix(mm->C1));
1274   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1275 #endif
1276 
1277   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1278 
1279   // A's off-diag * (F's diag + F's off-diag)
1280   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1281   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1282   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1283   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1284 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1285   PetscCallCXX(sort_crs_matrix(mm->C3));
1286   PetscCallCXX(sort_crs_matrix(mm->C4));
1287 #endif
1288 
1289   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1290   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1291   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1292   PetscCallCXX(Kokkos::parallel_for(
1293     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1294   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1295 
1296   // C = (Cd, Co) = (C1+C3, C2+C4)
1297   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1298   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1299   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1300   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1301   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1302   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1303   PetscFunctionReturn(PETSC_SUCCESS);
1304 }
1305 
1306 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1307 {
1308   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1309   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1310   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1311 
1312   PetscFunctionBegin;
1313   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1314   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1315   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1316   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1317 
1318   // Bcast B's rows to form F, and overlap the communication
1319   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1320 
1321   // A's diag * (B's diag + B's off-diag)
1322   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1323   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1324 
1325   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1326 
1327   // A's off-diag * (F's diag + F's off-diag)
1328   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1329   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1330 
1331   // C = (Cd, Co) = (C1+C3, C2+C4)
1332   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1333   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1334   PetscFunctionReturn(PETSC_SUCCESS);
1335 }
1336 
1337 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1338 {
1339   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1340   Mat_Product                 *product;
1341   MatProductData_MPIAIJKokkos *pdata;
1342   MatProductType               ptype;
1343   Mat                          A, B;
1344 
1345   PetscFunctionBegin;
1346   MatCheckProduct(C, 1); // make sure C is a product
1347   product = C->product;
1348   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1349   ptype   = product->type;
1350   A       = product->A;
1351   B       = product->B;
1352 
1353   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1354   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1355   // we still do numeric.
1356   if (pdata->reusesym) { // numeric reuses results from symbolic
1357     pdata->reusesym = PETSC_FALSE;
1358     PetscFunctionReturn(PETSC_SUCCESS);
1359   }
1360 
1361   if (ptype == MATPRODUCT_AB) {
1362     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1363   } else if (ptype == MATPRODUCT_AtB) {
1364     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1365   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1366     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1367     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1368   }
1369 
1370   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1371   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1372   PetscFunctionReturn(PETSC_SUCCESS);
1373 }
1374 
1375 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1376 {
1377   Mat                          A, B;
1378   Mat_Product                 *product;
1379   MatProductType               ptype;
1380   MatProductData_MPIAIJKokkos *pdata;
1381   MatMatStruct                *mm = NULL;
1382   PetscInt                     m, n, M, N;
1383   Mat                          Cd, Co;
1384   MPI_Comm                     comm;
1385 
1386   PetscFunctionBegin;
1387   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1388   MatCheckProduct(C, 1);
1389   product = C->product;
1390   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1391   ptype = product->type;
1392   A     = product->A;
1393   B     = product->B;
1394 
1395   switch (ptype) {
1396   case MATPRODUCT_AB:
1397     m = A->rmap->n;
1398     n = B->cmap->n;
1399     M = A->rmap->N;
1400     N = B->cmap->N;
1401     break;
1402   case MATPRODUCT_AtB:
1403     m = A->cmap->n;
1404     n = B->cmap->n;
1405     M = A->cmap->N;
1406     N = B->cmap->N;
1407     break;
1408   case MATPRODUCT_PtAP:
1409     m = B->cmap->n;
1410     n = B->cmap->n;
1411     M = B->cmap->N;
1412     N = B->cmap->N;
1413     break; /* BtAB */
1414   default:
1415     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1416   }
1417 
1418   PetscCall(MatSetSizes(C, m, n, M, N));
1419   PetscCall(PetscLayoutSetUp(C->rmap));
1420   PetscCall(PetscLayoutSetUp(C->cmap));
1421   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1422 
1423   pdata           = new MatProductData_MPIAIJKokkos();
1424   pdata->reusesym = product->api_user;
1425 
1426   if (ptype == MATPRODUCT_AB) {
1427     auto mmAB = new MatMatStruct_AB();
1428     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1429     mm = pdata->mmAB = mmAB;
1430   } else if (ptype == MATPRODUCT_AtB) {
1431     auto mmAtB = new MatMatStruct_AtB();
1432     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1433     mm = pdata->mmAtB = mmAtB;
1434   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1435     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1436 
1437     auto mmAB = new MatMatStruct_AB();
1438     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1439     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1440     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1441     pdata->mmAB = mmAB;
1442 
1443     m = A->rmap->n; // Z's layout
1444     n = B->cmap->n;
1445     M = A->rmap->N;
1446     N = B->cmap->N;
1447     PetscCall(MatCreate(comm, &Z));
1448     PetscCall(MatSetSizes(Z, m, n, M, N));
1449     PetscCall(PetscLayoutSetUp(Z->rmap));
1450     PetscCall(PetscLayoutSetUp(Z->cmap));
1451     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1452     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1453 
1454     auto mmAtB = new MatMatStruct_AtB();
1455     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1456 
1457     pdata->Z = Z; // give ownership to pdata
1458     mm = pdata->mmAtB = mmAtB;
1459   }
1460 
1461   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1462   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1463   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1464 
1465   C->product->data       = pdata;
1466   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1467   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1468   PetscFunctionReturn(PETSC_SUCCESS);
1469 }
1470 
1471 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1472 {
1473   Mat_Product *product = mat->product;
1474   PetscBool    match   = PETSC_FALSE;
1475   PetscBool    usecpu  = PETSC_FALSE;
1476 
1477   PetscFunctionBegin;
1478   MatCheckProduct(mat, 1);
1479   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1480   if (match) { /* we can always fallback to the CPU if requested */
1481     switch (product->type) {
1482     case MATPRODUCT_AB:
1483       if (product->api_user) {
1484         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1485         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1486         PetscOptionsEnd();
1487       } else {
1488         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1489         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1490         PetscOptionsEnd();
1491       }
1492       break;
1493     case MATPRODUCT_AtB:
1494       if (product->api_user) {
1495         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1496         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1497         PetscOptionsEnd();
1498       } else {
1499         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1500         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1501         PetscOptionsEnd();
1502       }
1503       break;
1504     case MATPRODUCT_PtAP:
1505       if (product->api_user) {
1506         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1507         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1508         PetscOptionsEnd();
1509       } else {
1510         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1511         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1512         PetscOptionsEnd();
1513       }
1514       break;
1515     default:
1516       break;
1517     }
1518     match = (PetscBool)!usecpu;
1519   }
1520   if (match) {
1521     switch (product->type) {
1522     case MATPRODUCT_AB:
1523     case MATPRODUCT_AtB:
1524     case MATPRODUCT_PtAP:
1525       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1526       break;
1527     default:
1528       break;
1529     }
1530   }
1531   /* fallback to MPIAIJ ops */
1532   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1533   PetscFunctionReturn(PETSC_SUCCESS);
1534 }
1535 
1536 // Mirror of MatCOOStruct_MPIAIJ on device
1537 struct MatCOOStruct_MPIAIJKokkos {
1538   PetscCount           n;
1539   PetscSF              sf;
1540   PetscCount           Annz, Bnnz;
1541   PetscCount           Annz2, Bnnz2;
1542   PetscCountKokkosView Ajmap1, Aperm1;
1543   PetscCountKokkosView Bjmap1, Bperm1;
1544   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1545   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1546   PetscCountKokkosView Cperm1;
1547   MatScalarKokkosView  sendbuf, recvbuf;
1548 
1549   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
1550     n(coo_h->n),
1551     sf(coo_h->sf),
1552     Annz(coo_h->Annz),
1553     Bnnz(coo_h->Bnnz),
1554     Annz2(coo_h->Annz2),
1555     Bnnz2(coo_h->Bnnz2),
1556     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
1557     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
1558     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
1559     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
1560     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
1561     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
1562     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
1563     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
1564     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
1565     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
1566     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
1567     sendbuf(Kokkos::create_mirror_view(DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
1568     recvbuf(Kokkos::create_mirror_view(DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
1569   {
1570     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1571   }
1572 
1573   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1574 };
1575 
1576 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1577 {
1578   PetscFunctionBegin;
1579   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1580   PetscFunctionReturn(PETSC_SUCCESS);
1581 }
1582 
1583 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1584 {
1585   PetscContainer             container_h, container_d;
1586   MatCOOStruct_MPIAIJ       *coo_h;
1587   MatCOOStruct_MPIAIJKokkos *coo_d;
1588 
1589   PetscFunctionBegin;
1590   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1591   mat->preallocated = PETSC_TRUE;
1592   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1593   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1594   PetscCall(MatZeroEntries(mat));
1595 
1596   // Copy the COO struct to device
1597   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1598   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1599   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1600 
1601   // Put the COO struct in a container and then attach that to the matrix
1602   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1603   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1604   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1605   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1606   PetscCall(PetscContainerDestroy(&container_d));
1607   PetscFunctionReturn(PETSC_SUCCESS);
1608 }
1609 
1610 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1611 {
1612   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1613   Mat                        A = mpiaij->A, B = mpiaij->B;
1614   MatScalarKokkosView        Aa, Ba;
1615   MatScalarKokkosView        v1;
1616   PetscMemType               memtype;
1617   PetscContainer             container;
1618   MatCOOStruct_MPIAIJKokkos *coo;
1619 
1620   PetscFunctionBegin;
1621   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1622   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1623 
1624   const auto &n      = coo->n;
1625   const auto &Annz   = coo->Annz;
1626   const auto &Annz2  = coo->Annz2;
1627   const auto &Bnnz   = coo->Bnnz;
1628   const auto &Bnnz2  = coo->Bnnz2;
1629   const auto &vsend  = coo->sendbuf;
1630   const auto &v2     = coo->recvbuf;
1631   const auto &Ajmap1 = coo->Ajmap1;
1632   const auto &Ajmap2 = coo->Ajmap2;
1633   const auto &Aimap2 = coo->Aimap2;
1634   const auto &Bjmap1 = coo->Bjmap1;
1635   const auto &Bjmap2 = coo->Bjmap2;
1636   const auto &Bimap2 = coo->Bimap2;
1637   const auto &Aperm1 = coo->Aperm1;
1638   const auto &Aperm2 = coo->Aperm2;
1639   const auto &Bperm1 = coo->Bperm1;
1640   const auto &Bperm2 = coo->Bperm2;
1641   const auto &Cperm1 = coo->Cperm1;
1642 
1643   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1644   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1645     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
1646   } else {
1647     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1648   }
1649 
1650   if (imode == INSERT_VALUES) {
1651     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1652     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1653   } else {
1654     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1655     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1656   }
1657 
1658   /* Pack entries to be sent to remote */
1659   Kokkos::parallel_for(
1660     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1661 
1662   /* Send remote entries to their owner and overlap the communication with local computation */
1663   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1664   /* Add local entries to A and B in one kernel */
1665   Kokkos::parallel_for(
1666     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1667       PetscScalar sum = 0.0;
1668       if (i < Annz) {
1669         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1670         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1671       } else {
1672         i -= Annz;
1673         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1674         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1675       }
1676     });
1677   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1678 
1679   /* Add received remote entries to A and B in one kernel */
1680   Kokkos::parallel_for(
1681     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1682       if (i < Annz2) {
1683         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1684       } else {
1685         i -= Annz2;
1686         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1687       }
1688     });
1689 
1690   if (imode == INSERT_VALUES) {
1691     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1692     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1693   } else {
1694     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1695     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1696   }
1697   PetscFunctionReturn(PETSC_SUCCESS);
1698 }
1699 
1700 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1701 {
1702   PetscFunctionBegin;
1703   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1704   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1705   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1706   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1707   PetscCall(MatDestroy_MPIAIJ(A));
1708   PetscFunctionReturn(PETSC_SUCCESS);
1709 }
1710 
1711 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1712 {
1713   PetscFunctionBegin;
1714   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1715   B->ops->mult                  = MatMult_MPIAIJKokkos;
1716   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1717   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1718   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1719   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1720 
1721   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1722   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1723   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1724   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1725   PetscFunctionReturn(PETSC_SUCCESS);
1726 }
1727 
1728 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1729 {
1730   Mat         B;
1731   Mat_MPIAIJ *a;
1732 
1733   PetscFunctionBegin;
1734   if (reuse == MAT_INITIAL_MATRIX) {
1735     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1736   } else if (reuse == MAT_REUSE_MATRIX) {
1737     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1738   }
1739   B = *newmat;
1740 
1741   B->boundtocpu = PETSC_FALSE;
1742   PetscCall(PetscFree(B->defaultvectype));
1743   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1744   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1745 
1746   a = static_cast<Mat_MPIAIJ *>(A->data);
1747   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1748   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1749   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1750   PetscCall(MatSetOps_MPIAIJKokkos(B));
1751   PetscFunctionReturn(PETSC_SUCCESS);
1752 }
1753 
1754 /*MC
1755    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1756 
1757    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
1758 
1759    Options Database Key:
1760 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1761 
1762   Level: beginner
1763 
1764 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1765 M*/
1766 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1767 {
1768   PetscFunctionBegin;
1769   PetscCall(PetscKokkosInitializeCheck());
1770   PetscCall(MatCreate_MPIAIJ(A));
1771   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1772   PetscFunctionReturn(PETSC_SUCCESS);
1773 }
1774 
1775 /*@C
1776    MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
1777    (the default parallel PETSc format).  This matrix will ultimately pushed down
1778    to Kokkos for calculations.
1779 
1780    Collective
1781 
1782    Input Parameters:
1783 +  comm - MPI communicator, set to `PETSC_COMM_SELF`
1784 .  m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1785            This value should be the same as the local size used in creating the
1786            y vector for the matrix-vector product y = Ax.
1787 .  n - This value should be the same as the local size used in creating the
1788        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1789        calculated if N is given) For square matrices n is almost always `m`.
1790 .  M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1791 .  N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1792 .  d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1793            (same value is used for all local rows)
1794 .  d_nnz - array containing the number of nonzeros in the various rows of the
1795            DIAGONAL portion of the local submatrix (possibly different for each row)
1796            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1797            The size of this array is equal to the number of local rows, i.e `m`.
1798            For matrices you plan to factor you must leave room for the diagonal entry and
1799            put in the entry even if it is zero.
1800 .  o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1801            submatrix (same value is used for all local rows).
1802 -  o_nnz - array containing the number of nonzeros in the various rows of the
1803            OFF-DIAGONAL portion of the local submatrix (possibly different for
1804            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1805            structure. The size of this array is equal to the number
1806            of local rows, i.e `m`.
1807 
1808    Output Parameter:
1809 .  A - the matrix
1810 
1811    Level: intermediate
1812 
1813    Notes:
1814    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1815    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1816    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1817 
1818    The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
1819    storage.  That is, the stored row and column indices can begin at
1820    either one (as in Fortran) or zero.
1821 
1822 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1823           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
1824 @*/
1825 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1826 {
1827   PetscMPIInt size;
1828 
1829   PetscFunctionBegin;
1830   PetscCall(MatCreate(comm, A));
1831   PetscCall(MatSetSizes(*A, m, n, M, N));
1832   PetscCallMPI(MPI_Comm_size(comm, &size));
1833   if (size > 1) {
1834     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1835     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1836   } else {
1837     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1838     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1839   }
1840   PetscFunctionReturn(PETSC_SUCCESS);
1841 }
1842