xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision f4f49eeac7efa77fffa46b7ff95a3ed169f659ed)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <KokkosSparse_spadd.hpp>
7 #include <KokkosSparse_spgemm.hpp>
8 
9 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
10 {
11   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
12 
13   PetscFunctionBegin;
14   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
15   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
16      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
17    */
18   if (mode == MAT_FINAL_ASSEMBLY) {
19     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
20     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
21     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
22   }
23   PetscFunctionReturn(PETSC_SUCCESS);
24 }
25 
26 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
27 {
28   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
29 
30   PetscFunctionBegin;
31   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
32   if (mat->hash_active) {
33     mat->ops[0]      = mpiaij->cops;
34     mat->hash_active = PETSC_FALSE;
35   }
36 
37   PetscCall(PetscLayoutSetUp(mat->rmap));
38   PetscCall(PetscLayoutSetUp(mat->cmap));
39 #if defined(PETSC_USE_DEBUG)
40   if (d_nnz) {
41     PetscInt i;
42     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
43   }
44   if (o_nnz) {
45     PetscInt i;
46     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
47   }
48 #endif
49 #if defined(PETSC_USE_CTABLE)
50   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
51 #else
52   PetscCall(PetscFree(mpiaij->colmap));
53 #endif
54   PetscCall(PetscFree(mpiaij->garray));
55   PetscCall(VecDestroy(&mpiaij->lvec));
56   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
57   /* Because the B will have been resized we simply destroy it and create a new one each time */
58   PetscCall(MatDestroy(&mpiaij->B));
59 
60   if (!mpiaij->A) {
61     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
62     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
63   }
64   if (!mpiaij->B) {
65     PetscMPIInt size;
66     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
67     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
68     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
69   }
70   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
71   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
72   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
73   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
74   mat->preallocated = PETSC_TRUE;
75   PetscFunctionReturn(PETSC_SUCCESS);
76 }
77 
78 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
79 {
80   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
81   PetscInt    nt;
82 
83   PetscFunctionBegin;
84   PetscCall(VecGetLocalSize(xx, &nt));
85   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
86   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
87   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
88   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
89   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
90   PetscFunctionReturn(PETSC_SUCCESS);
91 }
92 
93 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
94 {
95   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
96   PetscInt    nt;
97 
98   PetscFunctionBegin;
99   PetscCall(VecGetLocalSize(xx, &nt));
100   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
101   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
102   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
103   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
104   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
105   PetscFunctionReturn(PETSC_SUCCESS);
106 }
107 
108 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
109 {
110   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
111   PetscInt    nt;
112 
113   PetscFunctionBegin;
114   PetscCall(VecGetLocalSize(xx, &nt));
115   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
116   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
117   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
118   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
120   PetscFunctionReturn(PETSC_SUCCESS);
121 }
122 
123 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
124    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
125    C still uses local column ids. Their corresponding global column ids are returned in glob.
126 */
127 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
128 {
129   Mat             Ad, Ao;
130   const PetscInt *cmap;
131 
132   PetscFunctionBegin;
133   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
134   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
135   if (glob) {
136     PetscInt cst, i, dn, on, *gidx;
137     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
138     PetscCall(MatGetLocalSize(Ao, NULL, &on));
139     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
140     PetscCall(PetscMalloc1(dn + on, &gidx));
141     for (i = 0; i < dn; i++) gidx[i] = cst + i;
142     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
143     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
144   }
145   PetscFunctionReturn(PETSC_SUCCESS);
146 }
147 
148 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
149 struct MatMatStruct {
150   PetscInt            n, *garray;     // C's garray and its size.
151   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
152   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
153   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
154   PetscIntKokkosView  E_NzLeft;
155   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
156   MatScalarKokkosView rootBuf, leafBuf;
157   KokkosCsrMatrix     Fd, Fo; // F in split form
158 
159   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
160   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
161   KernelHandle kh3; // compute C3
162   KernelHandle kh4; // compute C4
163 
164   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
165   PetscInt E_VectorLength;
166   PetscInt E_RowsPerTeam;
167   PetscInt F_TeamSize;
168   PetscInt F_VectorLength;
169   PetscInt F_RowsPerTeam;
170 
171   ~MatMatStruct()
172   {
173     PetscFunctionBegin;
174     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
175     PetscFunctionReturnVoid();
176   }
177 };
178 
179 struct MatMatStruct_AB : public MatMatStruct {
180   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
181   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
182   PetscIntKokkosView rowoffset;
183 };
184 
185 struct MatMatStruct_AtB : public MatMatStruct {
186   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
187   MatColIdxKokkosView Fdjperm;
188   MatColIdxKokkosView Fojmap;
189   MatColIdxKokkosView Fojperm;
190 };
191 
192 struct MatProductData_MPIAIJKokkos {
193   MatMatStruct_AB  *mmAB     = nullptr;
194   MatMatStruct_AtB *mmAtB    = nullptr;
195   PetscBool         reusesym = PETSC_FALSE;
196   Mat               Z        = nullptr; // store Z=AB in computing BtAB
197 
198   ~MatProductData_MPIAIJKokkos()
199   {
200     delete mmAB;
201     delete mmAtB;
202     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
203   }
204 };
205 
206 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
207 {
208   PetscFunctionBegin;
209   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
210   PetscFunctionReturn(PETSC_SUCCESS);
211 }
212 
213 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
214    It is similar to MatCreateMPIAIJWithSplitArrays.
215 
216   Input Parameters:
217 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
218 .  A     - the diag matrix using local col ids
219 -  B     - the offdiag matrix using global col ids
220 
221   Output Parameter:
222 .  mat   - the updated MATMPIAIJKOKKOS matrix
223 */
224 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
225 {
226   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
227   PetscInt    m, n, M, N, Am, An, Bm, Bn;
228 
229   PetscFunctionBegin;
230   PetscCall(MatGetSize(mat, &M, &N));
231   PetscCall(MatGetLocalSize(mat, &m, &n));
232   PetscCall(MatGetLocalSize(A, &Am, &An));
233   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
234 
235   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
236   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
237   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
238   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
239   mpiaij->A      = A;
240   mpiaij->B      = B;
241   mpiaij->garray = garray;
242 
243   mat->preallocated     = PETSC_TRUE;
244   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
245 
246   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
247   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
248   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
249     also gets mpiaij->B compacted, with its col ids and size reduced
250   */
251   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
252   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
253   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
254   PetscFunctionReturn(PETSC_SUCCESS);
255 }
256 
257 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
258 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
259 template <class ExecutionSpace>
260 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
261 {
262   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
263 
264   PetscFunctionBegin;
265   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
266 
267   if (nnz_per_row < 1) nnz_per_row = 1;
268 
269   int max_vector_length = teamPolicy.vector_length_max();
270 
271   if (vector_length < 1) {
272     vector_length = 1;
273     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
274   }
275 
276   // Determine rows per thread
277   if (rows_per_thread < 1) {
278     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
279     else {
280       if (nnz_per_row < 20 && nnz > 5000000) {
281         rows_per_thread = 256;
282       } else rows_per_thread = 64;
283     }
284   }
285 
286   if (team_size < 1) {
287     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
288       team_size = 256 / vector_length;
289     } else {
290       team_size = 1;
291     }
292   }
293 
294   rows_per_team = rows_per_thread * team_size;
295 
296   if (rows_per_team < 0) {
297     PetscInt nnz_per_team = 4096;
298     PetscInt conc         = ExecutionSpace().concurrency();
299     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
300     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
301   }
302   PetscFunctionReturn(PETSC_SUCCESS);
303 }
304 
305 /*
306   Reduce two sets of global indices into local ones
307 
308   Input Parameters:
309 +  n1          - size of garray1[], the first set
310 .  garray1[n1] - a sorted global index array (without duplicates)
311 .  m           - size of indices[], the second set
312 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
313 
314   Output Parameters:
315 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
316 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
317 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
318 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
319 
320    Example, say
321     n1         = 5
322     garray1[5] = {1, 4, 7, 8, 10}
323     m          = 4
324     indices[4] = {2, 4, 8, 9}
325 
326    Combining them together, we have 7 global indices in garray2[]
327     n2         = 7
328     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
329 
330    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
331     map[5] = {0, 2, 3, 4, 6}
332 
333    On output, indices[] is updated with local indices
334     indices[4] = {1, 2, 4, 5}
335 */
336 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
337 {
338   PetscHMapI    g2l = nullptr;
339   PetscHashIter iter;
340   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
341   PetscInt      n2, *garray2;
342 
343   PetscFunctionBegin;
344   tot = 0;
345   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
346   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
347     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
348     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
349   }
350 
351   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
352     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
353     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
354   }
355 
356   // Pull out (unique) globals in the hash table and put them in garray2[]
357   n2 = tot;
358   PetscCall(PetscMalloc1(n2, &garray2));
359   tot = 0;
360   PetscHashIterBegin(g2l, iter);
361   while (!PetscHashIterAtEnd(g2l, iter)) {
362     PetscHashIterGetKey(g2l, iter, key);
363     PetscHashIterNext(g2l, iter);
364     garray2[tot++] = key;
365   }
366 
367   // Sort garray2[] and then map them to local indices starting from 0
368   PetscCall(PetscSortInt(n2, garray2));
369   PetscCall(PetscHMapIClear(g2l));
370   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
371 
372   // Rewrite indices[] with local indices
373   for (PetscInt i = 0; i < m; i++) {
374     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
375     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
376     indices[i] = val;
377   }
378   // Record the map that maps garray1[i] to garray2[map[i]]
379   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
380   PetscCall(PetscHMapIDestroy(&g2l));
381   *n2_      = n2;
382   *garray2_ = garray2;
383   PetscFunctionReturn(PETSC_SUCCESS);
384 }
385 
386 /*
387   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
388 
389   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
390 
391   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
392   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
393 
394   Input Parameters:
395 +  comm       - MPI communicator of E
396 .  A          - diag block of E, using local column indices
397 .  B          - off-diag block of E, using local column indices
398 .  cstart      - (global) start column of Ed
399 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
400 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
401 .  ownerSF     - the SF specifies ownership (root) of rows in E
402 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
403 -  mm          - to stash intermediate data structures for reuse
404 
405   Output Parameters:
406 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
407 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
408 
409   Notes:
410   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
411 
412  */
413 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
414 {
415   PetscFunctionBegin;
416   if (reuse == MAT_INITIAL_MATRIX) {
417     PetscInt Em = A.numRows(), Fm;
418     PetscInt n1 = B.numCols();
419 
420     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
421 
422     // Do the analysis on host
423     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
424     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
425     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
426     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
427     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
428     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
429 
430     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
431     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
432     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
433     for (PetscInt i = 0; i < Em; i++) {
434       const PetscInt *first, *last, *it;
435       PetscInt        count, step;
436       // std::lower_bound(first,last,cstart), but need to use global column indices
437       first = Bj + Bi[i];
438       last  = Bj + Bi[i + 1];
439       count = last - first;
440       while (count > 0) {
441         it   = first;
442         step = count / 2;
443         it += step;
444         if (garray1[*it] < cstart) { // map local to global
445           first = ++it;
446           count -= step + 1;
447         } else count = step;
448       }
449       E_NzLeft[i] = first - (Bj + Bi[i]);
450       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
451     }
452 
453     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
454     const PetscMPIInt *iranks, *ranks;
455     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
456     PetscInt           niranks, nranks;
457     MPI_Request       *reqs;
458     PetscMPIInt        tag;
459     PetscSF            reduceSF;
460     PetscInt          *sdisp, *rdisp;
461 
462     PetscCall(PetscCommGetNewTag(comm, &tag));
463     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
464     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
465 
466     // Find out length of each row I will receive. Even for the same row index, when they are from
467     // different senders, they might have different lengths (and sparsity patterns)
468     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
469     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
470 
471     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
472 
473     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
474     recvRowLen[0] = 0; // since we will make it in CSR format later
475     recvRowLen++;      // advance the pointer now
476     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
477     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
478     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
479 
480     // Build the real PetscSF for reducing E rows (buffer to buffer)
481     rdisp[0] = 0;
482     for (PetscInt i = 0; i < niranks; i++) {
483       rdisp[i + 1] = rdisp[i];
484       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
485     }
486     recvRowLen--; // put it back into csr format
487     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
488 
489     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
490     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
491     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
492 
493     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
494     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
495     PetscSFNode *iremote;
496 
497     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
498     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
499     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
500 
501     for (PetscInt i = 0; i < nranks; i++) {
502       PetscInt count = 0;
503       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
504       for (PetscInt j = 0; j < count; j++) {
505         iremote[nleaves + j].rank  = ranks[i];
506         iremote[nleaves + j].index = sdisp[i] + j;
507       }
508       nleaves += count;
509     }
510     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
511 
512     PetscCall(PetscSFCreate(comm, &reduceSF));
513     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
514 
515     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
516     PetscInt *sendCol, *recvCol;
517     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
518     for (PetscInt k = 0; k < roffset[nranks]; k++) {
519       PetscInt  i      = rmine[k]; // row to be copied
520       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
521       PetscInt  nzLeft = E_NzLeft[i];
522       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
523       for (PetscInt j = 0; j < alen + blen; j++) {
524         if (j < nzLeft) {
525           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
526         } else if (j < nzLeft + alen) {
527           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
528         } else {
529           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
530         }
531       }
532     }
533     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
534     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
535 
536     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
537     PetscInt *recvRowPerm, *recvColSorted;
538     PetscInt *recvNzPerm, *recvNzPermSorted;
539     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
540 
541     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
542     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
543     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
544 
545     // i[] array, nz are always easiest to compute
546     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
547     MatRowMapType          *Fdi, *Foi;
548     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
549     PetscInt                iter;
550 
551     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
552     Kokkos::deep_copy(Foi_h, 0);
553     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
554     Foi  = Foi_h.data() + 1;
555     iter = 0;
556     while (iter < recvRowCnt) { // iter over received rows
557       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
558       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
559 
560       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
561 
562       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
563       PetscInt  nz    = 0; // nz (with dups) in the current row
564       PetscInt *jbuf  = recvColSorted + FnzDups;
565       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
566       PetscInt *jbuf2 = jbuf; // temp pointers
567       PetscInt *pbuf2 = pbuf;
568       for (PetscInt d = 0; d < dupRows; d++) {
569         PetscInt i   = recvRowPerm[iter + d];
570         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
571         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
572         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
573         jbuf2 += len;
574         pbuf2 += len;
575         nz += len;
576       }
577       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
578 
579       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
580       PetscInt cur = 0;
581       while (cur < nz) {
582         PetscInt curColIdx = jbuf[cur];
583         PetscInt dups      = 1;
584 
585         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
586         if (curColIdx >= cstart && curColIdx < cend) {
587           Fdi[curRowIdx]++;
588           FdnzDups += dups;
589         } else {
590           Foi[curRowIdx]++;
591           FonzDups += dups;
592         }
593         cur += dups;
594       }
595 
596       FnzDups += nz;
597       iter += dupRows; // Move to next unique row
598     }
599 
600     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
601     Foi = Foi_h.data();
602     for (PetscInt i = 0; i < Fm; i++) {
603       Fdi[i + 1] += Fdi[i];
604       Foi[i + 1] += Foi[i];
605     }
606     Fdnz = Fdi[Fm];
607     Fonz = Foi[Fm];
608     PetscCall(PetscFree2(sendCol, recvCol));
609 
610     // Allocate j, jmap, jperm for Fd and Fo
611     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
612     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
613     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
614     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
615     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
616     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
617 
618     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
619     Fdjmap[0] = 0;
620     Fojmap[0] = 0;
621     FnzDups   = 0;
622     Fdnz      = 0;
623     Fonz      = 0;
624     iter      = 0; // iter over received rows
625     while (iter < recvRowCnt) {
626       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
627       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
628       PetscInt nz        = 0;                           // nz (with dups) in the current row
629 
630       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
631       for (PetscInt d = 0; d < dupRows; d++) {
632         PetscInt i = recvRowPerm[iter + d];
633         nz += recvRowLen[i + 1] - recvRowLen[i];
634       }
635 
636       PetscInt *jbuf = recvColSorted + FnzDups;
637       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
638       PetscInt cur = 0;
639       while (cur < nz) {
640         PetscInt curColIdx = jbuf[cur];
641         PetscInt dups      = 1;
642 
643         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
644         if (curColIdx >= cstart && curColIdx < cend) {
645           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
646           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
647           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
648           FdnzDups += dups;
649           Fdnz++;
650         } else {
651           Foj[Fonz]        = curColIdx; // in global
652           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
653           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
654           FonzDups += dups;
655           Fonz++;
656         }
657         cur += dups;
658         FnzDups += dups;
659       }
660       iter += dupRows; // Move to next unique row
661     }
662     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
663     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
664 
665     // Combine global column indices in garray1[] and Foj[]
666     PetscInt n2, *garray2;
667 
668     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
669     mm->sf       = reduceSF;
670     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
671     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
672     mm->garray   = garray2; // give ownership, so no free
673     mm->n        = n2;
674     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
675     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
676     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
677     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
678     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
679 
680     // Output Fd and Fo in KokkosCsrMatrix format
681     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
682     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
683     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
684     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
685     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
686     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
687 
688     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
689     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
690 
691     // Compute kernel launch parameters in merging E
692     PetscInt teamSize, vectorLength, rowsPerTeam;
693 
694     teamSize = vectorLength = rowsPerTeam = -1;
695     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
696     mm->E_TeamSize     = teamSize;
697     mm->E_VectorLength = vectorLength;
698     mm->E_RowsPerTeam  = rowsPerTeam;
699   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
700 
701   // Handy aliases
702   auto       &Aa           = A.values;
703   auto       &Ba           = B.values;
704   const auto &Ai           = A.graph.row_map;
705   const auto &Bi           = B.graph.row_map;
706   const auto &E_NzLeft     = mm->E_NzLeft;
707   auto       &leafBuf      = mm->leafBuf;
708   auto       &rootBuf      = mm->rootBuf;
709   PetscSF     reduceSF     = mm->sf;
710   PetscInt    Em           = A.numRows();
711   PetscInt    teamSize     = mm->E_TeamSize;
712   PetscInt    vectorLength = mm->E_VectorLength;
713   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
714   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
715 
716   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
717   PetscCallCXX(Kokkos::parallel_for(
718     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
719       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
720         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
721         if (i < Em) {
722           PetscInt disp   = Ai(i) + Bi(i);
723           PetscInt alen   = Ai(i + 1) - Ai(i);
724           PetscInt blen   = Bi(i + 1) - Bi(i);
725           PetscInt nzleft = E_NzLeft(i);
726 
727           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
728             MatScalar &val = leafBuf(disp + j);
729             if (j < nzleft) { // B left
730               val = Ba(Bi(i) + j);
731             } else if (j < nzleft + alen) { // diag A
732               val = Aa(Ai(i) + j - nzleft);
733             } else { // B right
734               val = Ba(Bi(i) + j - alen);
735             }
736           });
737         }
738       });
739     }));
740   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
741   PetscFunctionReturn(PETSC_SUCCESS);
742 }
743 
744 // To finish MatMPIAIJKokkosReduce.
745 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
746 {
747   PetscFunctionBegin;
748   auto       &leafBuf  = mm->leafBuf;
749   auto       &rootBuf  = mm->rootBuf;
750   auto       &Fda      = mm->Fd.values;
751   const auto &Fdjmap   = mm->Fdjmap;
752   const auto &Fdjperm  = mm->Fdjperm;
753   auto        Fdnz     = mm->Fd.nnz();
754   auto       &Foa      = mm->Fo.values;
755   const auto &Fojmap   = mm->Fojmap;
756   const auto &Fojperm  = mm->Fojperm;
757   auto        Fonz     = mm->Fo.nnz();
758   PetscSF     reduceSF = mm->sf;
759 
760   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
761 
762   // Reduce data in rootBuf to Fd and Fo
763   PetscCallCXX(Kokkos::parallel_for(
764     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
765       PetscScalar sum = 0.0;
766       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
767       Fda(i) = sum;
768     }));
769 
770   PetscCallCXX(Kokkos::parallel_for(
771     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
772       PetscScalar sum = 0.0;
773       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
774       Foa(i) = sum;
775     }));
776   PetscFunctionReturn(PETSC_SUCCESS);
777 }
778 
779 /*
780   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
781 
782   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
783   device and involves various index mapping.
784 
785   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
786   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
787   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
788   F has the same column layout as E.
789 
790   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
791   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
792   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
793   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
794   column indices in Fo and update Fo with local indices.
795 
796    Input Parameters:
797 +   E       - the MPIAIJKOKKOS matrix
798 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
799 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
800 -   mm      - to stash matproduct intermediate data structures
801 
802     Output Parameters:
803 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
804 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
805 
806     Notes:
807     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
808     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
809 */
810 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
811 {
812   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
813   Mat               A = empi->A, B = empi->B; // diag and off-diag
814   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
815   PetscInt          Em = E->rmap->n; // #local rows
816   MPI_Comm          comm;
817 
818   PetscFunctionBegin;
819   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
820   if (reuse == MAT_INITIAL_MATRIX) {
821     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
822     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
823     const PetscInt *garray1 = empi->garray; // its size is n1
824     PetscInt        cstart, cend;
825     PetscSF         bcastSF;
826 
827     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
828 
829     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
830     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
831     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
832     for (PetscInt i = 0; i < Em; i++) {
833       const PetscInt *first, *last, *it;
834       PetscInt        count, step;
835       // std::lower_bound(first,last,cstart), but need to use global column indices
836       first = Bj + Bi[i];
837       last  = Bj + Bi[i + 1];
838       count = last - first;
839       while (count > 0) {
840         it   = first;
841         step = count / 2;
842         it += step;
843         if (empi->garray[*it] < cstart) { // map local to global
844           first = ++it;
845           count -= step + 1;
846         } else count = step;
847       }
848       E_NzLeft[i] = first - (Bj + Bi[i]);
849       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
850     }
851 
852     // Compute row pointer Fi of F
853     PetscInt *Fi, Fm, Fnz;
854     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
855     PetscCall(PetscMalloc1(Fm + 1, &Fi));
856     Fi[0] = 0;
857     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
858     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
859     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
860     Fnz = Fi[Fm];
861 
862     // Build the real PetscSF for bcasting E rows (buffer to buffer)
863     const PetscMPIInt *iranks, *ranks;
864     const PetscInt    *ioffset, *irootloc, *roffset;
865     PetscInt           niranks, nranks, *sdisp, *rdisp;
866     MPI_Request       *reqs;
867     PetscMPIInt        tag;
868 
869     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
870     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
871     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
872 
873     sdisp[0] = 0; // send displacement
874     for (PetscInt i = 0; i < niranks; i++) {
875       sdisp[i + 1] = sdisp[i];
876       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
877         PetscInt r = irootloc[j]; // row to be sent
878         sdisp[i + 1] += E_RowLen[r];
879       }
880     }
881 
882     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
883     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
884     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
885     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
886 
887     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
888     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
889     PetscSFNode *iremote;                  // give ownership to bcastSF
890     PetscCall(PetscMalloc1(nleaves, &iremote));
891     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
892       PetscInt k = 0;
893       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
894         iremote[j].rank  = ranks[i];
895         iremote[j].index = rdisp[i] + k; // their root location
896         k++;
897       }
898     }
899     PetscCall(PetscSFCreate(comm, &bcastSF));
900     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
901     PetscCall(PetscFree3(sdisp, rdisp, reqs));
902 
903     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
904     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
905     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
906     rowoffset[0]                     = 0;
907     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
908 
909     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
910     PetscInt *jbuf, *Fj;
911     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
912     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
913       PetscInt  i      = irootloc[k]; // row to be copied
914       PetscInt *buf    = &jbuf[rowoffset[k]];
915       PetscInt  nzLeft = E_NzLeft[i];
916       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
917       for (PetscInt j = 0; j < alen + blen; j++) {
918         if (j < nzLeft) {
919           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
920         } else if (j < nzLeft + alen) {
921           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
922         } else {
923           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
924         }
925       }
926     }
927     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
928     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
929 
930     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
931     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
932     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
933     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
934     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
935 
936     Fdi[0] = Foi[0] = 0;
937     for (PetscInt i = 0; i < Fm; i++) {
938       PetscInt *first, *last, *lb1, *lb2;
939       // cut the row into: Left, [cstart, cend), Right
940       first       = Fj + Fi[i];
941       last        = Fj + Fi[i + 1];
942       lb1         = std::lower_bound(first, last, cstart);
943       F_NzLeft[i] = lb1 - first;
944       lb2         = std::lower_bound(first, last, cend);
945       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
946       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
947     }
948     for (PetscInt i = 0; i < Fm; i++) {
949       Fdi[i + 1] += Fdi[i];
950       Foi[i + 1] += Foi[i];
951     }
952 
953     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
954     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
955     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
956     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
957 
958     for (PetscInt i = 0; i < Fm; i++) {
959       PetscInt nzLeft = F_NzLeft[i];
960       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
961       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
962         gid = Fj[Fi[i] + j];
963         if (j < nzLeft) { // left, in global
964           Foj[Foi[i] + j] = gid;
965         } else if (j < nzLeft + len) { // diag, in local
966           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
967         } else { // right, in global
968           Foj[Foi[i] + j - len] = gid;
969         }
970       }
971     }
972     PetscCall(PetscFree2(jbuf, Fj));
973     PetscCall(PetscFree(Fi));
974 
975     // Reduce global indices in Foj[] and garray1[] into local ones
976     PetscInt n2, *garray2;
977     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
978 
979     // Record the plans built above, for reuse
980     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
981     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
982     Kokkos::deep_copy(irootloc_h, tmp);
983     mm->sf        = bcastSF;
984     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
985     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
986     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
987     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
988     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
989     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
990     mm->garray    = garray2;
991     mm->n         = n2;
992 
993     // Output Fd and Fo in KokkosCsrMatrix format
994     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
995     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
996     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
997     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
998     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
999 
1000     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1001     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1002 
1003     // Compute kernel launch parameters in merging E or splitting F
1004     PetscInt teamSize, vectorLength, rowsPerTeam;
1005 
1006     teamSize = vectorLength = rowsPerTeam = -1;
1007     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1008     mm->E_TeamSize     = teamSize;
1009     mm->E_VectorLength = vectorLength;
1010     mm->E_RowsPerTeam  = rowsPerTeam;
1011 
1012     teamSize = vectorLength = rowsPerTeam = -1;
1013     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1014     mm->F_TeamSize     = teamSize;
1015     mm->F_VectorLength = vectorLength;
1016     mm->F_RowsPerTeam  = rowsPerTeam;
1017   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1018 
1019   // Sync E's value to device
1020   akok->a_dual.sync_device();
1021   bkok->a_dual.sync_device();
1022 
1023   // Handy aliases
1024   const auto &Aa = akok->a_dual.view_device();
1025   const auto &Ba = bkok->a_dual.view_device();
1026   const auto &Ai = akok->i_dual.view_device();
1027   const auto &Bi = bkok->i_dual.view_device();
1028 
1029   // Fetch the plans
1030   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1031   PetscSF             &bcastSF   = mm->sf;
1032   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1033   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1034   PetscIntKokkosView  &irootloc  = mm->irootloc;
1035   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1036 
1037   PetscInt teamSize     = mm->E_TeamSize;
1038   PetscInt vectorLength = mm->E_VectorLength;
1039   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1040   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1041 
1042   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1043   PetscCallCXX(Kokkos::parallel_for(
1044     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1045       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1046         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1047         if (r < irootloc.extent(0)) {
1048           PetscInt i      = irootloc(r); // row i of E
1049           PetscInt disp   = rowoffset(r);
1050           PetscInt alen   = Ai(i + 1) - Ai(i);
1051           PetscInt blen   = Bi(i + 1) - Bi(i);
1052           PetscInt nzleft = E_NzLeft(i);
1053 
1054           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1055             if (j < nzleft) { // B left
1056               rootBuf(disp + j) = Ba(Bi(i) + j);
1057             } else if (j < nzleft + alen) { // diag A
1058               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1059             } else { // B right
1060               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1061             }
1062           });
1063         }
1064       });
1065     }));
1066   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1067   PetscFunctionReturn(PETSC_SUCCESS);
1068 }
1069 
1070 // To finish MatMPIAIJKokkosBcast.
1071 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1072 {
1073   PetscFunctionBegin;
1074   const auto &Fd  = mm->Fd;
1075   const auto &Fo  = mm->Fo;
1076   const auto &Fdi = Fd.graph.row_map;
1077   const auto &Foi = Fo.graph.row_map;
1078   auto       &Fda = Fd.values;
1079   auto       &Foa = Fo.values;
1080   auto        Fm  = Fd.numRows();
1081 
1082   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1083   PetscSF             &bcastSF      = mm->sf;
1084   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1085   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1086   PetscInt             teamSize     = mm->F_TeamSize;
1087   PetscInt             vectorLength = mm->F_VectorLength;
1088   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1089   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1090 
1091   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1092 
1093   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1094   PetscCallCXX(Kokkos::parallel_for(
1095     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1096       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1097         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1098         if (i < Fm) {
1099           PetscInt nzLeft = F_NzLeft(i);
1100           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1101           PetscInt blen   = Foi(i + 1) - Foi(i);
1102           PetscInt Fii    = Fdi(i) + Foi(i);
1103 
1104           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1105             PetscScalar val = leafBuf(Fii + j);
1106             if (j < nzLeft) { // left
1107               Foa(Foi(i) + j) = val;
1108             } else if (j < nzLeft + alen) { // diag
1109               Fda(Fdi(i) + j - nzLeft) = val;
1110             } else { // right
1111               Foa(Foi(i) + j - alen) = val;
1112             }
1113           });
1114         }
1115       });
1116     }));
1117   PetscFunctionReturn(PETSC_SUCCESS);
1118 }
1119 
/*
  MatProductSymbolic_MPIAIJKokkos_AtB - symbolic phase of C = A^T * B for MPIAIJKOKKOS matrices A and B

  The transposes of A's diag block (Adt) and off-diag block (Aot) are generated, and the following products are formed:
    C3 = Aot * Bd, C4     = Aot * Bo   (together called E; reduced to other ranks through ampi->Mvctx to form F = (Fd, Fo))
    C1 = Adt * Bd, C2_mid = Adt * Bo
  C2 is C2_mid with column indices remapped into the merged column space of size mm->n. The result is Cd = C1 + Fd, Co = C2 + Fo.
  The E reduction is started before and finished after computing C1/C2_mid, to overlap communication with computation.

  Input Parameters:
+  product - the Mat_Product (not referenced in this routine)
.  A, B    - MPIAIJKOKKOS matrices
-  mm      - stash for intermediate data (kernel handles, C1..C4, Fd, Fo, Cd, Co) reused by MatProductNumeric_MPIAIJKokkos_AtB()
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
  PetscInt        cstart, cend;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
  // NOTE(review): Ad and Ao are fetched but not used below; only their transposes Adt/Aot are. Kept in case the calls have side effects.
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One handle per product: kh1: Adt*Bd, kh2: Adt*Bo, kh3: Aot*Bd, kh4: Aot*Bo (kh1/kh2 are reused for the spadds below)
  PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));

  // Aot * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication.
  // map_h[] has one entry per column of B's off-diag block; ReduceBegin fills it with the mapping from bmpi->garray[] to mm->garray[].
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
  PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
  PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Adt * (B's diag + B's off-diag), computed while the reduction above is in flight
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  // (columns are remapped on device through map[], i.e., from B's off-diag column space into the mm->n merged columns)
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  PetscCallCXX(Kokkos::parallel_for(
    oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));

  // C = (C1+Fd, C2+Fo)
  PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
  PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1197 
1198 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1199 {
1200   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1201   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1202   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1203   MPI_Comm        comm;
1204 
1205   PetscFunctionBegin;
1206   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1207   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1208   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1209   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1210   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1211 
1212   // Aot * (B's diag + B's off-diag)
1213   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1214   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1215 
1216   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1217   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1218 
1219   // Adt * (B's diag + B's off-diag)
1220   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1221   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1222 
1223   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1224 
1225   // C = (C1+Fd, C2+Fo)
1226   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1227   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1228   PetscFunctionReturn(PETSC_SUCCESS);
1229 }
1230 
/* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos

  Builds the local blocks of C = A*B as mm->Cd (diagonal block) and mm->Co (off-diagonal block). Besides the
  sparsity patterns, the calls below also compute the values (spgemm_numeric/spadd_numeric are invoked here too).

  With A = (Ad, Ao) and B = (Bd, Bo) split into their diagonal/off-diagonal blocks:
    C1 = Ad*Bd, C2 = Ad*Bo (C2_mid with remapped column indices), C3 = Ao*Fd, C4 = Ao*Fo
    Cd = C1 + C3, Co = C2 + C4
  where F = (Fd, Fo) is formed from B's rows by MatMPIAIJKokkosBcastBegin/End() (presumably the rows referenced by
  Ao's columns -- see that routine).

  Input Parameters:
+  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product
              (not referenced in this body).
.  A        - an MPIAIJKOKKOS matrix
-  B        - an MPIAIJKOKKOS matrix

  Input/Output Parameter:
.  mm       - a struct used to stash intermediate data when computing AB (kernel handles, C1..C4, F, Cd, Co).
              Persists from symbolic to numeric operations.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Ad, Ao, Bd, Bo;

  PetscFunctionBegin;
  // Device CSR views of the four local blocks
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One kernel handle per partial product; kh1/kh2 are reused below for the spadd of Cd/Co
  mm->kh1.create_spgemm_handle(spgemm_alg);
  mm->kh2.create_spgemm_handle(spgemm_alg);
  mm->kh3.create_spgemm_handle(spgemm_alg);
  mm->kh4.create_spgemm_handle(spgemm_alg);

  // Bcast B's rows to form F, and overlap the communication
  // map_h receives, for each column of B's off-diag block, its column index in C's off-diag block (used for C2 below)
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
  PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // A's diag * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
  // Older Kokkos-Kernels do not guarantee sorted column indices; spadd below is told the inputs are sorted
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  // F (mm->Fd, mm->Fo) and map_h are available after this call
  PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // A's off-diag * (F's diag + F's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  // newj(i) = map(oldj(i)): translate each column index of C2_mid through the device copy of map_h
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  PetscCallCXX(Kokkos::parallel_for(
    oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);

  // C = (Cd, Co) = (C1+C3, C2+C4)
  // NOTE(review): passing `true` (sorted) for C2 assumes the column remap above preserves ordering within rows
  mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
  mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1310 
1311 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1312 {
1313   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1314   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1315   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1316 
1317   PetscFunctionBegin;
1318   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1319   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1320   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1321   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1322 
1323   // Bcast B's rows to form F, and overlap the communication
1324   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1325 
1326   // A's diag * (B's diag + B's off-diag)
1327   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1328   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1329 
1330   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1331 
1332   // A's off-diag * (F's diag + F's off-diag)
1333   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1334   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1335 
1336   // C = (Cd, Co) = (C1+C3, C2+C4)
1337   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1338   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1339   PetscFunctionReturn(PETSC_SUCCESS);
1340 }
1341 
1342 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1343 {
1344   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1345   Mat_Product                 *product;
1346   MatProductData_MPIAIJKokkos *pdata;
1347   MatProductType               ptype;
1348   Mat                          A, B;
1349 
1350   PetscFunctionBegin;
1351   MatCheckProduct(C, 1); // make sure C is a product
1352   product = C->product;
1353   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1354   ptype   = product->type;
1355   A       = product->A;
1356   B       = product->B;
1357 
1358   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1359   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1360   // we still do numeric.
1361   if (pdata->reusesym) { // numeric reuses results from symbolic
1362     pdata->reusesym = PETSC_FALSE;
1363     PetscFunctionReturn(PETSC_SUCCESS);
1364   }
1365 
1366   if (ptype == MATPRODUCT_AB) {
1367     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1368   } else if (ptype == MATPRODUCT_AtB) {
1369     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1370   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1371     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1372     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1373   }
1374 
1375   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1376   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1377   PetscFunctionReturn(PETSC_SUCCESS);
1378 }
1379 
/* MatProductSymbolic_MPIAIJKokkos - symbolic phase of C = A*B, At*B or Bt*A*B (MATPRODUCT_PtAP) for MPIAIJKOKKOS

  Sets C's sizes and type and runs the flavor-specific symbolic routine(s). The AB flavor also computes values.
  The resulting local blocks mm->Cd / mm->Co are wrapped as SEQAIJKOKKOS matrices and installed as C's diagonal and
  off-diagonal parts. The intermediate data is stashed in C->product->data for MatProductNumeric_MPIAIJKokkos().

  Input Parameter:
. C - the product matrix; C->product must be set and C->product->data must still be empty

  Notes:
  pdata->reusesym is set from product->api_user; MatProductNumeric_MPIAIJKokkos() then skips its first call.
  NOTE(review): if a PetscCall() below fails, pdata and the MatMatStruct objects allocated with `new` are leaked.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *pdata;
  MatMatStruct                *mm = NULL;
  PetscInt                     m, n, M, N;
  Mat                          Cd, Co;
  MPI_Comm                     comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
  MatCheckProduct(C, 1);
  product = C->product;
  PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  // Local (m, n) and global (M, N) sizes of the result C for each supported product type
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  // C takes the same type as A
  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  pdata           = new MatProductData_MPIAIJKokkos();
  pdata->reusesym = product->api_user;

  // mm ends up pointing at the struct that holds the final (Cd, Co) blocks of C
  if (ptype == MATPRODUCT_AB) {
    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
    mm = pdata->mmAB = mmAB;
  } else if (ptype == MATPRODUCT_AtB) {
    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
    mm = pdata->mmAtB = mmAtB;
  } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
    Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z

    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
    pdata->mmAB = mmAB;

    // Wrap Z = A*B as a parallel MPIAIJKOKKOS matrix so it can be the right operand of Bt*Z
    m = A->rmap->n; // Z's layout
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    PetscCall(MatCreate(comm, &Z));
    PetscCall(MatSetSizes(Z, m, n, M, N));
    PetscCall(PetscLayoutSetUp(Z->rmap));
    PetscCall(PetscLayoutSetUp(Z->cmap));
    PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
    PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));

    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

    pdata->Z = Z; // give ownership to pdata
    mm = pdata->mmAtB = mmAtB;
  }

  // Install the computed local blocks as C's diagonal / off-diagonal parts
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
  PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));

  // Hand the intermediate data to C's product object; it is released by MatProductDataDestroy_MPIAIJKokkos
  C->product->data       = pdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1475 
1476 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1477 {
1478   Mat_Product *product = mat->product;
1479   PetscBool    match   = PETSC_FALSE;
1480   PetscBool    usecpu  = PETSC_FALSE;
1481 
1482   PetscFunctionBegin;
1483   MatCheckProduct(mat, 1);
1484   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1485   if (match) { /* we can always fallback to the CPU if requested */
1486     switch (product->type) {
1487     case MATPRODUCT_AB:
1488       if (product->api_user) {
1489         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1490         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1491         PetscOptionsEnd();
1492       } else {
1493         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1494         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1495         PetscOptionsEnd();
1496       }
1497       break;
1498     case MATPRODUCT_AtB:
1499       if (product->api_user) {
1500         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1501         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1502         PetscOptionsEnd();
1503       } else {
1504         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1505         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1506         PetscOptionsEnd();
1507       }
1508       break;
1509     case MATPRODUCT_PtAP:
1510       if (product->api_user) {
1511         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1512         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1513         PetscOptionsEnd();
1514       } else {
1515         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1516         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1517         PetscOptionsEnd();
1518       }
1519       break;
1520     default:
1521       break;
1522     }
1523     match = (PetscBool)!usecpu;
1524   }
1525   if (match) {
1526     switch (product->type) {
1527     case MATPRODUCT_AB:
1528     case MATPRODUCT_AtB:
1529     case MATPRODUCT_PtAP:
1530       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1531       break;
1532     default:
1533       break;
1534     }
1535   }
1536   /* fallback to MPIAIJ ops */
1537   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1538   PetscFunctionReturn(PETSC_SUCCESS);
1539 }
1540 
1541 // Mirror of MatCOOStruct_MPIAIJ on device
1542 struct MatCOOStruct_MPIAIJKokkos {
1543   PetscCount           n;
1544   PetscSF              sf;
1545   PetscCount           Annz, Bnnz;
1546   PetscCount           Annz2, Bnnz2;
1547   PetscCountKokkosView Ajmap1, Aperm1;
1548   PetscCountKokkosView Bjmap1, Bperm1;
1549   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1550   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1551   PetscCountKokkosView Cperm1;
1552   MatScalarKokkosView  sendbuf, recvbuf;
1553 
1554   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
1555     n(coo_h->n),
1556     sf(coo_h->sf),
1557     Annz(coo_h->Annz),
1558     Bnnz(coo_h->Bnnz),
1559     Annz2(coo_h->Annz2),
1560     Bnnz2(coo_h->Bnnz2),
1561     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
1562     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
1563     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
1564     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
1565     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
1566     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
1567     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
1568     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
1569     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
1570     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
1571     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
1572     sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
1573     recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
1574   {
1575     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1576   }
1577 
1578   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1579 };
1580 
1581 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1582 {
1583   PetscFunctionBegin;
1584   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1585   PetscFunctionReturn(PETSC_SUCCESS);
1586 }
1587 
1588 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1589 {
1590   PetscContainer             container_h, container_d;
1591   MatCOOStruct_MPIAIJ       *coo_h;
1592   MatCOOStruct_MPIAIJKokkos *coo_d;
1593 
1594   PetscFunctionBegin;
1595   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1596   mat->preallocated = PETSC_TRUE;
1597   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1598   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1599   PetscCall(MatZeroEntries(mat));
1600 
1601   // Copy the COO struct to device
1602   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1603   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1604   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1605 
1606   // Put the COO struct in a container and then attach that to the matrix
1607   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1608   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1609   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1610   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1611   PetscCall(PetscContainerDestroy(&container_d));
1612   PetscFunctionReturn(PETSC_SUCCESS);
1613 }
1614 
/* MatSetValuesCOO_MPIAIJKokkos - sets or adds COO values v[] into an MPIAIJKOKKOS matrix on the device

  Uses the device COO struct attached by MatSetPreallocationCOO_MPIAIJKokkos() under "__PETSc_MatCOOStruct_Device".
  v[] may be in host or device memory; host values are copied to the device first.

  Input Parameters:
+ mat   - the matrix
. v     - coo->n values (NOTE(review): presumably in the order of coo_i/coo_j given at preallocation; not visible here)
- imode - INSERT_VALUES overwrites A's and B's values; any other mode adds to the existing values

  Notes:
  Locally owned entries are reduced into A/B while the entries owned by other ranks are in flight. The received
  entries are then added (+=) on top, so contributions to the same nonzero are summed even with INSERT_VALUES.
*/
static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
{
  Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
  Mat                        A = mpiaij->A, B = mpiaij->B;
  MatScalarKokkosView        Aa, Ba;
  MatScalarKokkosView        v1;
  PetscMemType               memtype;
  PetscContainer             container;
  MatCOOStruct_MPIAIJKokkos *coo;

  PetscFunctionBegin;
  PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
  PetscCall(PetscContainerGetPointer(container, (void **)&coo));

  // Local aliases of the device COO maps (referenced by value inside the device lambdas below)
  const auto &n      = coo->n;
  const auto &Annz   = coo->Annz;
  const auto &Annz2  = coo->Annz2;
  const auto &Bnnz   = coo->Bnnz;
  const auto &Bnnz2  = coo->Bnnz2;
  const auto &vsend  = coo->sendbuf;
  const auto &v2     = coo->recvbuf;
  const auto &Ajmap1 = coo->Ajmap1;
  const auto &Ajmap2 = coo->Ajmap2;
  const auto &Aimap2 = coo->Aimap2;
  const auto &Bjmap1 = coo->Bjmap1;
  const auto &Bjmap2 = coo->Bjmap2;
  const auto &Bimap2 = coo->Bimap2;
  const auto &Aperm1 = coo->Aperm1;
  const auto &Aperm2 = coo->Aperm2;
  const auto &Bperm1 = coo->Bperm1;
  const auto &Bperm2 = coo->Bperm2;
  const auto &Cperm1 = coo->Cperm1;

  PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
  if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
    v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
  } else {
    v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
  }

  // Write-only access suffices for INSERT_VALUES (every value is reset in the first kernel); ADD needs read & write
  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
    PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
    PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
  }

  PetscCall(PetscLogGpuTimeBegin());
  /* Pack entries to be sent to remote */
  Kokkos::parallel_for(
    vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

  /* Send remote entries to their owner and overlap the communication with local computation */
  PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
  /* Add local entries to A and B in one kernel */
  // Indices [0, Annz) address A's values, [Annz, Annz+Bnnz) address B's values (shifted back by Annz)
  Kokkos::parallel_for(
    Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
      PetscScalar sum = 0.0;
      if (i < Annz) {
        for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
        Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
      } else {
        i -= Annz;
        for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
        Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
      }
    });
  PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

  /* Add received remote entries to A and B in one kernel */
  // Aimap2/Bimap2 give the target value index; always accumulated on top of the local contribution
  Kokkos::parallel_for(
    Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
      if (i < Annz2) {
        for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
      } else {
        i -= Annz2;
        for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
      }
    });
  PetscCall(PetscLogGpuTimeEnd());

  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
    PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1706 
1707 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1708 {
1709   PetscFunctionBegin;
1710   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1711   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1712   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1713   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1714   PetscCall(MatDestroy_MPIAIJ(A));
1715   PetscFunctionReturn(PETSC_SUCCESS);
1716 }
1717 
1718 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1719 {
1720   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1721   PetscBool   congruent;
1722 
1723   PetscFunctionBegin;
1724   PetscCall(MatHasCongruentLayouts(A, &congruent));
1725   if (congruent) { // square matrix and the diagonals are solely in the diag block
1726     PetscCall(MatShift(mpiaij->A, a));
1727   } else { // too hard, use the general version
1728     PetscCall(MatShift_Basic(A, a));
1729   }
1730   PetscFunctionReturn(PETSC_SUCCESS);
1731 }
1732 
1733 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1734 {
1735   PetscFunctionBegin;
1736   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1737   B->ops->mult                  = MatMult_MPIAIJKokkos;
1738   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1739   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1740   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1741   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1742   B->ops->shift                 = MatShift_MPIAIJKokkos;
1743 
1744   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1745   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1746   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1747   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1748   PetscFunctionReturn(PETSC_SUCCESS);
1749 }
1750 
1751 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1752 {
1753   Mat         B;
1754   Mat_MPIAIJ *a;
1755 
1756   PetscFunctionBegin;
1757   if (reuse == MAT_INITIAL_MATRIX) {
1758     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1759   } else if (reuse == MAT_REUSE_MATRIX) {
1760     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1761   }
1762   B = *newmat;
1763 
1764   B->boundtocpu = PETSC_FALSE;
1765   PetscCall(PetscFree(B->defaultvectype));
1766   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1767   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1768 
1769   a = static_cast<Mat_MPIAIJ *>(A->data);
1770   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1771   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1772   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1773   PetscCall(MatSetOps_MPIAIJKokkos(B));
1774   PetscFunctionReturn(PETSC_SUCCESS);
1775 }
1776 
1777 /*MC
1778    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1779 
1780    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1781 
1782    Options Database Key:
1783 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1784 
1785   Level: beginner
1786 
1787 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1788 M*/
1789 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1790 {
1791   PetscFunctionBegin;
1792   PetscCall(PetscKokkosInitializeCheck());
1793   PetscCall(MatCreate_MPIAIJ(A));
1794   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1795   PetscFunctionReturn(PETSC_SUCCESS);
1796 }
1797 
1798 /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1802 
1803   Collective
1804 
1805   Input Parameters:
+ comm  - MPI communicator
. m     - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given)
           This value should be the same as the local size used in creating the
           y vector for the matrix-vector product y = Ax.
. n     - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given)
           This value should be the same as the local size used in creating the
           x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
. M     - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
. N     - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
1815 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1816            (same value is used for all local rows)
1817 . d_nnz - array containing the number of nonzeros in the various rows of the
1818            DIAGONAL portion of the local submatrix (possibly different for each row)
1819            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1820            The size of this array is equal to the number of local rows, i.e `m`.
1821            For matrices you plan to factor you must leave room for the diagonal entry and
1822            put in the entry even if it is zero.
1823 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1824            submatrix (same value is used for all local rows).
1825 - o_nnz - array containing the number of nonzeros in the various rows of the
1826            OFF-DIAGONAL portion of the local submatrix (possibly different for
1827            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1828            structure. The size of this array is equal to the number
1829            of local rows, i.e `m`.
1830 
1831   Output Parameter:
1832 . A - the matrix
1833 
1834   Level: intermediate
1835 
1836   Notes:
1837   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1838   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1839   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1840 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1842   storage.  That is, the stored row and column indices can begin at
1843   either one (as in Fortran) or zero.
1844 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1847 @*/
1848 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1849 {
1850   PetscMPIInt size;
1851 
1852   PetscFunctionBegin;
1853   PetscCall(MatCreate(comm, A));
1854   PetscCall(MatSetSizes(*A, m, n, M, N));
1855   PetscCallMPI(MPI_Comm_size(comm, &size));
1856   if (size > 1) {
1857     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1858     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1859   } else {
1860     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1861     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1862   }
1863   PetscFunctionReturn(PETSC_SUCCESS);
1864 }
1865