xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision f13dfd9ea68e0ddeee984e65c377a1819eab8a8a)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/sfimpl.h>
5 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
6 #include <../src/mat/impls/aij/mpi/mpiaij.h>
7 #include <KokkosSparse_spadd.hpp>
8 #include <KokkosSparse_spgemm.hpp>
9 
10 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11 {
12   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
13 
14   PetscFunctionBegin;
15   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
16   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
17      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
18    */
19   if (mode == MAT_FINAL_ASSEMBLY) {
20     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
21     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
22     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
23   }
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
27 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
28 {
29   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
30 
31   PetscFunctionBegin;
32   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
33   if (mat->hash_active) {
34     mat->ops[0]      = mpiaij->cops;
35     mat->hash_active = PETSC_FALSE;
36   }
37 
38   PetscCall(PetscLayoutSetUp(mat->rmap));
39   PetscCall(PetscLayoutSetUp(mat->cmap));
40 #if defined(PETSC_USE_DEBUG)
41   if (d_nnz) {
42     PetscInt i;
43     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
44   }
45   if (o_nnz) {
46     PetscInt i;
47     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
48   }
49 #endif
50 #if defined(PETSC_USE_CTABLE)
51   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
52 #else
53   PetscCall(PetscFree(mpiaij->colmap));
54 #endif
55   PetscCall(PetscFree(mpiaij->garray));
56   PetscCall(VecDestroy(&mpiaij->lvec));
57   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
58   /* Because the B will have been resized we simply destroy it and create a new one each time */
59   PetscCall(MatDestroy(&mpiaij->B));
60 
61   if (!mpiaij->A) {
62     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
63     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
64   }
65   if (!mpiaij->B) {
66     PetscMPIInt size;
67     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
68     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
69     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
70   }
71   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
72   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
73   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
74   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
75   mat->preallocated = PETSC_TRUE;
76   PetscFunctionReturn(PETSC_SUCCESS);
77 }
78 
79 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
80 {
81   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
82   PetscInt    nt;
83 
84   PetscFunctionBegin;
85   PetscCall(VecGetLocalSize(xx, &nt));
86   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
87   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
88   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
89   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
90   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
91   PetscFunctionReturn(PETSC_SUCCESS);
92 }
93 
94 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
95 {
96   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
97   PetscInt    nt;
98 
99   PetscFunctionBegin;
100   PetscCall(VecGetLocalSize(xx, &nt));
101   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
102   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
103   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
104   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
105   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
106   PetscFunctionReturn(PETSC_SUCCESS);
107 }
108 
109 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
110 {
111   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
112   PetscInt    nt;
113 
114   PetscFunctionBegin;
115   PetscCall(VecGetLocalSize(xx, &nt));
116   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
117   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
118   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
119   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
120   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
121   PetscFunctionReturn(PETSC_SUCCESS);
122 }
123 
124 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
125    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
126    C still uses local column ids. Their corresponding global column ids are returned in glob.
127 */
128 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
129 {
130   Mat             Ad, Ao;
131   const PetscInt *cmap;
132 
133   PetscFunctionBegin;
134   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
135   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
136   if (glob) {
137     PetscInt cst, i, dn, on, *gidx;
138     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
139     PetscCall(MatGetLocalSize(Ao, NULL, &on));
140     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
141     PetscCall(PetscMalloc1(dn + on, &gidx));
142     for (i = 0; i < dn; i++) gidx[i] = cst + i;
143     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
144     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
145   }
146   PetscFunctionReturn(PETSC_SUCCESS);
147 }
148 
149 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
150 struct MatMatStruct {
151   PetscInt            n, *garray;     // C's garray and its size.
152   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
153   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
154   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
155   PetscIntKokkosView  E_NzLeft;
156   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
157   MatScalarKokkosView rootBuf, leafBuf;
158   KokkosCsrMatrix     Fd, Fo; // F in split form
159 
160   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
161   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
162   KernelHandle kh3; // compute C3
163   KernelHandle kh4; // compute C4
164 
165   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
166   PetscInt E_VectorLength;
167   PetscInt E_RowsPerTeam;
168   PetscInt F_TeamSize;
169   PetscInt F_VectorLength;
170   PetscInt F_RowsPerTeam;
171 
172   ~MatMatStruct()
173   {
174     PetscFunctionBegin;
175     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
176     PetscFunctionReturnVoid();
177   }
178 };
179 
180 struct MatMatStruct_AB : public MatMatStruct {
181   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
182   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
183   PetscIntKokkosView rowoffset;
184 };
185 
186 struct MatMatStruct_AtB : public MatMatStruct {
187   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
188   MatColIdxKokkosView Fdjperm;
189   MatColIdxKokkosView Fojmap;
190   MatColIdxKokkosView Fojperm;
191 };
192 
193 struct MatProductData_MPIAIJKokkos {
194   MatMatStruct_AB  *mmAB     = nullptr;
195   MatMatStruct_AtB *mmAtB    = nullptr;
196   PetscBool         reusesym = PETSC_FALSE;
197   Mat               Z        = nullptr; // store Z=AB in computing BtAB
198 
199   ~MatProductData_MPIAIJKokkos()
200   {
201     delete mmAB;
202     delete mmAtB;
203     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
204   }
205 };
206 
207 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
208 {
209   PetscFunctionBegin;
210   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
211   PetscFunctionReturn(PETSC_SUCCESS);
212 }
213 
214 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
215    It is similar to MatCreateMPIAIJWithSplitArrays.
216 
217   Input Parameters:
218 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
219 .  A     - the diag matrix using local col ids
220 -  B     - the offdiag matrix using global col ids
221 
222   Output Parameter:
223 .  mat   - the updated MATMPIAIJKOKKOS matrix
224 */
225 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
226 {
227   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
228   PetscInt    m, n, M, N, Am, An, Bm, Bn;
229 
230   PetscFunctionBegin;
231   PetscCall(MatGetSize(mat, &M, &N));
232   PetscCall(MatGetLocalSize(mat, &m, &n));
233   PetscCall(MatGetLocalSize(A, &Am, &An));
234   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
235 
236   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
237   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
238   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
239   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
240   mpiaij->A      = A;
241   mpiaij->B      = B;
242   mpiaij->garray = garray;
243 
244   mat->preallocated     = PETSC_TRUE;
245   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
246 
247   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
248   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
249   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
250     also gets mpiaij->B compacted, with its col ids and size reduced
251   */
252   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
253   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
254   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
255   PetscFunctionReturn(PETSC_SUCCESS);
256 }
257 
258 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
259 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
260 template <class ExecutionSpace>
261 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
262 {
263   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
264 
265   PetscFunctionBegin;
266   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
267 
268   if (nnz_per_row < 1) nnz_per_row = 1;
269 
270   int max_vector_length = teamPolicy.vector_length_max();
271 
272   if (vector_length < 1) {
273     vector_length = 1;
274     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
275   }
276 
277   // Determine rows per thread
278   if (rows_per_thread < 1) {
279     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
280     else {
281       if (nnz_per_row < 20 && nnz > 5000000) {
282         rows_per_thread = 256;
283       } else rows_per_thread = 64;
284     }
285   }
286 
287   if (team_size < 1) {
288     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
289       team_size = 256 / vector_length;
290     } else {
291       team_size = 1;
292     }
293   }
294 
295   rows_per_team = rows_per_thread * team_size;
296 
297   if (rows_per_team < 0) {
298     PetscInt nnz_per_team = 4096;
299     PetscInt conc         = ExecutionSpace().concurrency();
300     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
301     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
302   }
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 /*
307   Reduce two sets of global indices into local ones
308 
309   Input Parameters:
310 +  n1          - size of garray1[], the first set
311 .  garray1[n1] - a sorted global index array (without duplicates)
312 .  m           - size of indices[], the second set
313 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
314 
315   Output Parameters:
316 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
317 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
318 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
319 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
320 
321    Example, say
322     n1         = 5
323     garray1[5] = {1, 4, 7, 8, 10}
324     m          = 4
325     indices[4] = {2, 4, 8, 9}
326 
327    Combining them together, we have 7 global indices in garray2[]
328     n2         = 7
329     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
330 
331    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
332     map[5] = {0, 2, 3, 4, 6}
333 
334    On output, indices[] is updated with local indices
335     indices[4] = {1, 2, 4, 5}
336 */
337 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
338 {
339   PetscHMapI    g2l = nullptr;
340   PetscHashIter iter;
341   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
342   PetscInt      n2, *garray2;
343 
344   PetscFunctionBegin;
345   tot = 0;
346   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
347   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
348     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
349     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
350   }
351 
352   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
353     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
354     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
355   }
356 
357   // Pull out (unique) globals in the hash table and put them in garray2[]
358   n2 = tot;
359   PetscCall(PetscMalloc1(n2, &garray2));
360   tot = 0;
361   PetscHashIterBegin(g2l, iter);
362   while (!PetscHashIterAtEnd(g2l, iter)) {
363     PetscHashIterGetKey(g2l, iter, key);
364     PetscHashIterNext(g2l, iter);
365     garray2[tot++] = key;
366   }
367 
368   // Sort garray2[] and then map them to local indices starting from 0
369   PetscCall(PetscSortInt(n2, garray2));
370   PetscCall(PetscHMapIClear(g2l));
371   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
372 
373   // Rewrite indices[] with local indices
374   for (PetscInt i = 0; i < m; i++) {
375     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
376     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
377     indices[i] = val;
378   }
379   // Record the map that maps garray1[i] to garray2[map[i]]
380   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
381   PetscCall(PetscHMapIDestroy(&g2l));
382   *n2_      = n2;
383   *garray2_ = garray2;
384   PetscFunctionReturn(PETSC_SUCCESS);
385 }
386 
387 /*
388   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
389 
390   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
391 
392   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
393   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
394 
395   Input Parameters:
396 +  comm       - MPI communicator of E
397 .  A          - diag block of E, using local column indices
398 .  B          - off-diag block of E, using local column indices
399 .  cstart      - (global) start column of Ed
400 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
401 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
402 .  ownerSF     - the SF specifies ownership (root) of rows in E
403 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
404 -  mm          - to stash intermediate data structures for reuse
405 
406   Output Parameters:
407 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
408 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
409 
410   Notes:
411   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
412 
413  */
414 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
415 {
416   PetscFunctionBegin;
417   if (reuse == MAT_INITIAL_MATRIX) {
418     PetscInt Em = A.numRows(), Fm;
419     PetscInt n1 = B.numCols();
420 
421     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
422 
423     // Do the analysis on host
424     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
425     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
426     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
427     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
428     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
429     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
430 
431     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
432     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
433     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
434     for (PetscInt i = 0; i < Em; i++) {
435       const PetscInt *first, *last, *it;
436       PetscInt        count, step;
437       // std::lower_bound(first,last,cstart), but need to use global column indices
438       first = Bj + Bi[i];
439       last  = Bj + Bi[i + 1];
440       count = last - first;
441       while (count > 0) {
442         it   = first;
443         step = count / 2;
444         it += step;
445         if (garray1[*it] < cstart) { // map local to global
446           first = ++it;
447           count -= step + 1;
448         } else count = step;
449       }
450       E_NzLeft[i] = first - (Bj + Bi[i]);
451       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
452     }
453 
454     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
455     const PetscMPIInt *iranks, *ranks;
456     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
457     PetscInt           niranks, nranks;
458     MPI_Request       *reqs;
459     PetscMPIInt        tag;
460     PetscSF            reduceSF;
461     PetscInt          *sdisp, *rdisp;
462 
463     PetscCall(PetscCommGetNewTag(comm, &tag));
464     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
465     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
466 
467     // Find out length of each row I will receive. Even for the same row index, when they are from
468     // different senders, they might have different lengths (and sparsity patterns)
469     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
470     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
471 
472     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
473 
474     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
475     recvRowLen[0] = 0; // since we will make it in CSR format later
476     recvRowLen++;      // advance the pointer now
477     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
478     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
479     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
480 
481     // Build the real PetscSF for reducing E rows (buffer to buffer)
482     rdisp[0] = 0;
483     for (PetscInt i = 0; i < niranks; i++) {
484       rdisp[i + 1] = rdisp[i];
485       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
486     }
487     recvRowLen--; // put it back into csr format
488     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
489 
490     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
491     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
492     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
493 
494     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
495     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
496     PetscSFNode *iremote;
497 
498     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
499     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
500     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
501 
502     for (PetscInt i = 0; i < nranks; i++) {
503       PetscInt count = 0;
504       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
505       for (PetscInt j = 0; j < count; j++) {
506         iremote[nleaves + j].rank  = ranks[i];
507         iremote[nleaves + j].index = sdisp[i] + j;
508       }
509       nleaves += count;
510     }
511     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
512 
513     PetscCall(PetscSFCreate(comm, &reduceSF));
514     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
515 
516     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
517     PetscInt *sendCol, *recvCol;
518     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
519     for (PetscInt k = 0; k < roffset[nranks]; k++) {
520       PetscInt  i      = rmine[k]; // row to be copied
521       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
522       PetscInt  nzLeft = E_NzLeft[i];
523       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
524       for (PetscInt j = 0; j < alen + blen; j++) {
525         if (j < nzLeft) {
526           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
527         } else if (j < nzLeft + alen) {
528           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
529         } else {
530           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
531         }
532       }
533     }
534     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
535     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
536 
537     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
538     PetscInt *recvRowPerm, *recvColSorted;
539     PetscInt *recvNzPerm, *recvNzPermSorted;
540     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
541 
542     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
543     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
544     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
545 
546     // i[] array, nz are always easiest to compute
547     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
548     MatRowMapType          *Fdi, *Foi;
549     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
550     PetscInt                iter;
551 
552     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
553     Kokkos::deep_copy(Foi_h, 0);
554     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
555     Foi  = Foi_h.data() + 1;
556     iter = 0;
557     while (iter < recvRowCnt) { // iter over received rows
558       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
559       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
560 
561       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
562 
563       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
564       PetscInt  nz    = 0; // nz (with dups) in the current row
565       PetscInt *jbuf  = recvColSorted + FnzDups;
566       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
567       PetscInt *jbuf2 = jbuf; // temp pointers
568       PetscInt *pbuf2 = pbuf;
569       for (PetscInt d = 0; d < dupRows; d++) {
570         PetscInt i   = recvRowPerm[iter + d];
571         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
572         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
573         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
574         jbuf2 += len;
575         pbuf2 += len;
576         nz += len;
577       }
578       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
579 
580       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
581       PetscInt cur = 0;
582       while (cur < nz) {
583         PetscInt curColIdx = jbuf[cur];
584         PetscInt dups      = 1;
585 
586         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
587         if (curColIdx >= cstart && curColIdx < cend) {
588           Fdi[curRowIdx]++;
589           FdnzDups += dups;
590         } else {
591           Foi[curRowIdx]++;
592           FonzDups += dups;
593         }
594         cur += dups;
595       }
596 
597       FnzDups += nz;
598       iter += dupRows; // Move to next unique row
599     }
600 
601     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
602     Foi = Foi_h.data();
603     for (PetscInt i = 0; i < Fm; i++) {
604       Fdi[i + 1] += Fdi[i];
605       Foi[i + 1] += Foi[i];
606     }
607     Fdnz = Fdi[Fm];
608     Fonz = Foi[Fm];
609     PetscCall(PetscFree2(sendCol, recvCol));
610 
611     // Allocate j, jmap, jperm for Fd and Fo
612     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
613     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
614     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
615     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
616     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
617     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
618 
619     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
620     Fdjmap[0] = 0;
621     Fojmap[0] = 0;
622     FnzDups   = 0;
623     Fdnz      = 0;
624     Fonz      = 0;
625     iter      = 0; // iter over received rows
626     while (iter < recvRowCnt) {
627       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
628       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
629       PetscInt nz        = 0;                           // nz (with dups) in the current row
630 
631       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
632       for (PetscInt d = 0; d < dupRows; d++) {
633         PetscInt i = recvRowPerm[iter + d];
634         nz += recvRowLen[i + 1] - recvRowLen[i];
635       }
636 
637       PetscInt *jbuf = recvColSorted + FnzDups;
638       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
639       PetscInt cur = 0;
640       while (cur < nz) {
641         PetscInt curColIdx = jbuf[cur];
642         PetscInt dups      = 1;
643 
644         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
645         if (curColIdx >= cstart && curColIdx < cend) {
646           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
647           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
648           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
649           FdnzDups += dups;
650           Fdnz++;
651         } else {
652           Foj[Fonz]        = curColIdx; // in global
653           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
654           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
655           FonzDups += dups;
656           Fonz++;
657         }
658         cur += dups;
659         FnzDups += dups;
660       }
661       iter += dupRows; // Move to next unique row
662     }
663     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
664     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
665 
666     // Combine global column indices in garray1[] and Foj[]
667     PetscInt n2, *garray2;
668 
669     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
670     mm->sf       = reduceSF;
671     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
672     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
673     mm->garray   = garray2; // give ownership, so no free
674     mm->n        = n2;
675     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
676     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
677     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
678     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
679     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
680 
681     // Output Fd and Fo in KokkosCsrMatrix format
682     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
683     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
684     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
685     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
686     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
687     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
688 
689     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
690     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
691 
692     // Compute kernel launch parameters in merging E
693     PetscInt teamSize, vectorLength, rowsPerTeam;
694 
695     teamSize = vectorLength = rowsPerTeam = -1;
696     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
697     mm->E_TeamSize     = teamSize;
698     mm->E_VectorLength = vectorLength;
699     mm->E_RowsPerTeam  = rowsPerTeam;
700   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
701 
702   // Handy aliases
703   auto       &Aa           = A.values;
704   auto       &Ba           = B.values;
705   const auto &Ai           = A.graph.row_map;
706   const auto &Bi           = B.graph.row_map;
707   const auto &E_NzLeft     = mm->E_NzLeft;
708   auto       &leafBuf      = mm->leafBuf;
709   auto       &rootBuf      = mm->rootBuf;
710   PetscSF     reduceSF     = mm->sf;
711   PetscInt    Em           = A.numRows();
712   PetscInt    teamSize     = mm->E_TeamSize;
713   PetscInt    vectorLength = mm->E_VectorLength;
714   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
715   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
716 
717   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
718   PetscCallCXX(Kokkos::parallel_for(
719     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
720       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
721         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
722         if (i < Em) {
723           PetscInt disp   = Ai(i) + Bi(i);
724           PetscInt alen   = Ai(i + 1) - Ai(i);
725           PetscInt blen   = Bi(i + 1) - Bi(i);
726           PetscInt nzleft = E_NzLeft(i);
727 
728           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
729             MatScalar &val = leafBuf(disp + j);
730             if (j < nzleft) { // B left
731               val = Ba(Bi(i) + j);
732             } else if (j < nzleft + alen) { // diag A
733               val = Aa(Ai(i) + j - nzleft);
734             } else { // B right
735               val = Ba(Bi(i) + j - alen);
736             }
737           });
738         }
739       });
740     }));
741   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
742   PetscFunctionReturn(PETSC_SUCCESS);
743 }
744 
745 // To finish MatMPIAIJKokkosReduce.
746 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
747 {
748   auto       &leafBuf  = mm->leafBuf;
749   auto       &rootBuf  = mm->rootBuf;
750   auto       &Fda      = mm->Fd.values;
751   const auto &Fdjmap   = mm->Fdjmap;
752   const auto &Fdjperm  = mm->Fdjperm;
753   auto        Fdnz     = mm->Fd.nnz();
754   auto       &Foa      = mm->Fo.values;
755   const auto &Fojmap   = mm->Fojmap;
756   const auto &Fojperm  = mm->Fojperm;
757   auto        Fonz     = mm->Fo.nnz();
758   PetscSF     reduceSF = mm->sf;
759 
760   PetscFunctionBegin;
761   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
762 
763   // Reduce data in rootBuf to Fd and Fo
764   PetscCallCXX(Kokkos::parallel_for(
765     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
766       PetscScalar sum = 0.0;
767       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
768       Fda(i) = sum;
769     }));
770 
771   PetscCallCXX(Kokkos::parallel_for(
772     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
773       PetscScalar sum = 0.0;
774       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
775       Foa(i) = sum;
776     }));
777   PetscFunctionReturn(PETSC_SUCCESS);
778 }
779 
780 /*
781   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
782 
783   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
784   device and involves various index mapping.
785 
786   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
787   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
788   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
789   F has the same column layout as E.
790 
791   Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo.
792   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
793   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
794   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
795   column indices in Fo and update Fo with local indices.
796 
797    Input Parameters:
798 +   E       - the MPIAIJKOKKOS matrix
799 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
800 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
801 -   mm      - to stash matproduct intermediate data structures
802 
803     Output Parameters:
804 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
805 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
806 
807     Notes:
808     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
809     The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities.
810 */
811 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
812 {
813   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
814   Mat               A = empi->A, B = empi->B; // diag and off-diag
815   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
816   PetscInt          Em = E->rmap->n; // #local rows
817   MPI_Comm          comm;
818 
819   PetscFunctionBegin;
820   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
821   if (reuse == MAT_INITIAL_MATRIX) {
822     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
823     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
824     const PetscInt *garray1 = empi->garray; // its size is n1
825     PetscInt        cstart, cend;
826     PetscSF         bcastSF;
827 
828     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
829 
830     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
831     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
832     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
833     for (PetscInt i = 0; i < Em; i++) {
834       const PetscInt *first, *last, *it;
835       PetscInt        count, step;
836       // std::lower_bound(first,last,cstart), but need to use global column indices
837       first = Bj + Bi[i];
838       last  = Bj + Bi[i + 1];
839       count = last - first;
840       while (count > 0) {
841         it   = first;
842         step = count / 2;
843         it += step;
844         if (empi->garray[*it] < cstart) { // map local to global
845           first = ++it;
846           count -= step + 1;
847         } else count = step;
848       }
849       E_NzLeft[i] = first - (Bj + Bi[i]);
850       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
851     }
852 
853     // Compute row pointer Fi of F
854     PetscInt *Fi, Fm, Fnz;
855     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
856     PetscCall(PetscMalloc1(Fm + 1, &Fi));
857     Fi[0] = 0;
858     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
859     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
860     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
861     Fnz = Fi[Fm];
862 
863     // Build the real PetscSF for bcasting E rows (buffer to buffer)
864     const PetscMPIInt *iranks, *ranks;
865     const PetscInt    *ioffset, *irootloc, *roffset;
866     PetscInt           niranks, nranks, *sdisp, *rdisp;
867     MPI_Request       *reqs;
868     PetscMPIInt        tag;
869 
870     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
871     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
872     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
873 
874     sdisp[0] = 0; // send displacement
875     for (PetscInt i = 0; i < niranks; i++) {
876       sdisp[i + 1] = sdisp[i];
877       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
878         PetscInt r = irootloc[j]; // row to be sent
879         sdisp[i + 1] += E_RowLen[r];
880       }
881     }
882 
883     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
884     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
885     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
886     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
887 
888     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
889     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
890     PetscSFNode *iremote;                  // give ownership to bcastSF
891     PetscCall(PetscMalloc1(nleaves, &iremote));
892     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
893       PetscInt k = 0;
894       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
895         iremote[j].rank  = ranks[i];
896         iremote[j].index = rdisp[i] + k; // their root location
897         k++;
898       }
899     }
900     PetscCall(PetscSFCreate(comm, &bcastSF));
901     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
902     PetscCall(PetscFree3(sdisp, rdisp, reqs));
903 
904     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
905     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
906     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
907     rowoffset[0]                     = 0;
908     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
909 
910     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
911     PetscInt *jbuf, *Fj;
912     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
913     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
914       PetscInt  i      = irootloc[k]; // row to be copied
915       PetscInt *buf    = &jbuf[rowoffset[k]];
916       PetscInt  nzLeft = E_NzLeft[i];
917       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
918       for (PetscInt j = 0; j < alen + blen; j++) {
919         if (j < nzLeft) {
920           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
921         } else if (j < nzLeft + alen) {
922           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
923         } else {
924           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
925         }
926       }
927     }
928     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
929     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
930 
931     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
932     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
933     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
934     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
935     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
936 
937     Fdi[0] = Foi[0] = 0;
938     for (PetscInt i = 0; i < Fm; i++) {
939       PetscInt *first, *last, *lb1, *lb2;
940       // cut the row into: Left, [cstart, cend), Right
941       first       = Fj + Fi[i];
942       last        = Fj + Fi[i + 1];
943       lb1         = std::lower_bound(first, last, cstart);
944       F_NzLeft[i] = lb1 - first;
945       lb2         = std::lower_bound(first, last, cend);
946       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
947       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
948     }
949     for (PetscInt i = 0; i < Fm; i++) {
950       Fdi[i + 1] += Fdi[i];
951       Foi[i + 1] += Foi[i];
952     }
953 
954     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
955     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
956     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
957     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
958 
959     for (PetscInt i = 0; i < Fm; i++) {
960       PetscInt nzLeft = F_NzLeft[i];
961       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
962       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
963         gid = Fj[Fi[i] + j];
964         if (j < nzLeft) { // left, in global
965           Foj[Foi[i] + j] = gid;
966         } else if (j < nzLeft + len) { // diag, in local
967           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
968         } else { // right, in global
969           Foj[Foi[i] + j - len] = gid;
970         }
971       }
972     }
973     PetscCall(PetscFree2(jbuf, Fj));
974     PetscCall(PetscFree(Fi));
975 
976     // Reduce global indices in Foj[] and garray1[] into local ones
977     PetscInt n2, *garray2;
978     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
979 
980     // Record the plans built above, for reuse
981     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
982     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
983     Kokkos::deep_copy(irootloc_h, tmp);
984     mm->sf        = bcastSF;
985     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
986     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
987     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
988     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
989     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
990     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
991     mm->garray    = garray2;
992     mm->n         = n2;
993 
994     // Output Fd and Fo in KokkosCsrMatrix format
995     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
996     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
997     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
998     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
999     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
1000 
1001     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1002     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1003 
1004     // Compute kernel launch parameters in merging E or splitting F
1005     PetscInt teamSize, vectorLength, rowsPerTeam;
1006 
1007     teamSize = vectorLength = rowsPerTeam = -1;
1008     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1009     mm->E_TeamSize     = teamSize;
1010     mm->E_VectorLength = vectorLength;
1011     mm->E_RowsPerTeam  = rowsPerTeam;
1012 
1013     teamSize = vectorLength = rowsPerTeam = -1;
1014     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1015     mm->F_TeamSize     = teamSize;
1016     mm->F_VectorLength = vectorLength;
1017     mm->F_RowsPerTeam  = rowsPerTeam;
1018   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1019 
1020   // Sync E's value to device
1021   akok->a_dual.sync_device();
1022   bkok->a_dual.sync_device();
1023 
1024   // Handy aliases
1025   const auto &Aa = akok->a_dual.view_device();
1026   const auto &Ba = bkok->a_dual.view_device();
1027   const auto &Ai = akok->i_dual.view_device();
1028   const auto &Bi = bkok->i_dual.view_device();
1029 
1030   // Fetch the plans
1031   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1032   PetscSF             &bcastSF   = mm->sf;
1033   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1034   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1035   PetscIntKokkosView  &irootloc  = mm->irootloc;
1036   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1037 
1038   PetscInt teamSize     = mm->E_TeamSize;
1039   PetscInt vectorLength = mm->E_VectorLength;
1040   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1041   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1042 
1043   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1044   PetscCallCXX(Kokkos::parallel_for(
1045     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1046       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1047         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1048         if (r < irootloc.extent(0)) {
1049           PetscInt i      = irootloc(r); // row i of E
1050           PetscInt disp   = rowoffset(r);
1051           PetscInt alen   = Ai(i + 1) - Ai(i);
1052           PetscInt blen   = Bi(i + 1) - Bi(i);
1053           PetscInt nzleft = E_NzLeft(i);
1054 
1055           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1056             if (j < nzleft) { // B left
1057               rootBuf(disp + j) = Ba(Bi(i) + j);
1058             } else if (j < nzleft + alen) { // diag A
1059               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1060             } else { // B right
1061               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1062             }
1063           });
1064         }
1065       });
1066     }));
1067   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1068   PetscFunctionReturn(PETSC_SUCCESS);
1069 }
1070 
1071 // To finish MatMPIAIJKokkosBcast.
1072 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1073 {
1074   PetscFunctionBegin;
1075   const auto &Fd  = mm->Fd;
1076   const auto &Fo  = mm->Fo;
1077   const auto &Fdi = Fd.graph.row_map;
1078   const auto &Foi = Fo.graph.row_map;
1079   auto       &Fda = Fd.values;
1080   auto       &Foa = Fo.values;
1081   auto        Fm  = Fd.numRows();
1082 
1083   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1084   PetscSF             &bcastSF      = mm->sf;
1085   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1086   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1087   PetscInt             teamSize     = mm->F_TeamSize;
1088   PetscInt             vectorLength = mm->F_VectorLength;
1089   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1090   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1091 
1092   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1093 
1094   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1095   PetscCallCXX(Kokkos::parallel_for(
1096     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1097       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1098         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1099         if (i < Fm) {
1100           PetscInt nzLeft = F_NzLeft(i);
1101           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1102           PetscInt blen   = Foi(i + 1) - Foi(i);
1103           PetscInt Fii    = Fdi(i) + Foi(i);
1104 
1105           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1106             PetscScalar val = leafBuf(Fii + j);
1107             if (j < nzLeft) { // left
1108               Foa(Foi(i) + j) = val;
1109             } else if (j < nzLeft + alen) { // diag
1110               Fda(Fdi(i) + j - nzLeft) = val;
1111             } else { // right
1112               Foa(Foi(i) + j - alen) = val;
1113             }
1114           });
1115         }
1116       });
1117     }));
1118   PetscFunctionReturn(PETSC_SUCCESS);
1119 }
1120 
1121 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1122 {
1123   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1124   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1125   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1126   PetscInt        cstart, cend;
1127   MPI_Comm        comm;
1128 
1129   PetscFunctionBegin;
1130   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1131   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1132   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1133   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1134   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1135   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1136   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1137 
1138   // TODO: add command line options to select spgemm algorithms
1139   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1140 
1141   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1142 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1143   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1144   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1145   #endif
1146 #endif
1147 
1148   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1149   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1150   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1151   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1152 
1153   // Aot * (B's diag + B's off-diag)
1154   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1155   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1156   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1157   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1158   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1159   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1160 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1161 
1162   PetscCallCXX(sort_crs_matrix(mm->C3));
1163   PetscCallCXX(sort_crs_matrix(mm->C4));
1164 #endif
1165 
1166   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1167   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1168   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1169   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1170 
1171   // Adt * (B's diag + B's off-diag)
1172   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1173   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1174   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1175   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1176 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1177   PetscCallCXX(sort_crs_matrix(mm->C1));
1178   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1179 #endif
1180 
1181   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1182 
1183   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1184   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1185   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1186   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1187   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1188 
1189   // C = (C1+Fd, C2+Fo)
1190   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1191   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1192   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1193   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1194   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1195   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1196   PetscFunctionReturn(PETSC_SUCCESS);
1197 }
1198 
1199 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1200 {
1201   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1202   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1203   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1204   MPI_Comm        comm;
1205 
1206   PetscFunctionBegin;
1207   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1208   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1209   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1210   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1211   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1212 
1213   // Aot * (B's diag + B's off-diag)
1214   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1215   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1216 
1217   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1218   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1219 
1220   // Adt * (B's diag + B's off-diag)
1221   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1222   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1223 
1224   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1225 
1226   // C = (C1+Fd, C2+Fo)
1227   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1228   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1229   PetscFunctionReturn(PETSC_SUCCESS);
1230 }
1231 
1232 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1233 
1234   Input Parameters:
1235 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1236 .  A        - an MPIAIJKOKKOS matrix
1237 .  B        - an MPIAIJKOKKOS matrix
1238 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1239 */
1240 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241 {
1242   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1243   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1244   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245 
1246   PetscFunctionBegin;
1247   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1248   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1249   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1250   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251 
1252   // TODO: add command line options to select spgemm algorithms
1253   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1254 
1255   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1256 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1257   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1258   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1259   #endif
1260 #endif
1261 
1262   mm->kh1.create_spgemm_handle(spgemm_alg);
1263   mm->kh2.create_spgemm_handle(spgemm_alg);
1264   mm->kh3.create_spgemm_handle(spgemm_alg);
1265   mm->kh4.create_spgemm_handle(spgemm_alg);
1266 
1267   // Bcast B's rows to form F, and overlap the communication
1268   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1269   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1270 
1271   // A's diag * (B's diag + B's off-diag)
1272   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1273   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1274   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1275   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1276   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1277   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1278 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1279   PetscCallCXX(sort_crs_matrix(mm->C1));
1280   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1281 #endif
1282 
1283   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1284 
1285   // A's off-diag * (F's diag + F's off-diag)
1286   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1287   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1288   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1289   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1290 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1291   PetscCallCXX(sort_crs_matrix(mm->C3));
1292   PetscCallCXX(sort_crs_matrix(mm->C4));
1293 #endif
1294 
1295   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1296   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1297   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1298   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1299   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1300 
1301   // C = (Cd, Co) = (C1+C3, C2+C4)
1302   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1303   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1304   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1305   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1306   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1307   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1308   PetscFunctionReturn(PETSC_SUCCESS);
1309 }
1310 
1311 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1312 {
1313   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1314   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1315   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1316 
1317   PetscFunctionBegin;
1318   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1319   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1320   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1321   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1322 
1323   // Bcast B's rows to form F, and overlap the communication
1324   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1325 
1326   // A's diag * (B's diag + B's off-diag)
1327   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1328   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1329 
1330   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1331 
1332   // A's off-diag * (F's diag + F's off-diag)
1333   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1334   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1335 
1336   // C = (Cd, Co) = (C1+C3, C2+C4)
1337   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1338   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1339   PetscFunctionReturn(PETSC_SUCCESS);
1340 }
1341 
1342 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1343 {
1344   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1345   Mat_Product                 *product;
1346   MatProductData_MPIAIJKokkos *pdata;
1347   MatProductType               ptype;
1348   Mat                          A, B;
1349 
1350   PetscFunctionBegin;
1351   MatCheckProduct(C, 1); // make sure C is a product
1352   product = C->product;
1353   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1354   ptype   = product->type;
1355   A       = product->A;
1356   B       = product->B;
1357 
1358   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1359   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1360   // we still do numeric.
1361   if (pdata->reusesym) { // numeric reuses results from symbolic
1362     pdata->reusesym = PETSC_FALSE;
1363     PetscFunctionReturn(PETSC_SUCCESS);
1364   }
1365 
1366   if (ptype == MATPRODUCT_AB) {
1367     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1368   } else if (ptype == MATPRODUCT_AtB) {
1369     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1370   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1371     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1372     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1373   }
1374 
1375   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1376   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1377   PetscFunctionReturn(PETSC_SUCCESS);
1378 }
1379 
1380 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1381 {
1382   Mat                          A, B;
1383   Mat_Product                 *product;
1384   MatProductType               ptype;
1385   MatProductData_MPIAIJKokkos *pdata;
1386   MatMatStruct                *mm = NULL;
1387   PetscInt                     m, n, M, N;
1388   Mat                          Cd, Co;
1389   MPI_Comm                     comm;
1390 
1391   PetscFunctionBegin;
1392   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1393   MatCheckProduct(C, 1);
1394   product = C->product;
1395   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1396   ptype = product->type;
1397   A     = product->A;
1398   B     = product->B;
1399 
1400   switch (ptype) {
1401   case MATPRODUCT_AB:
1402     m = A->rmap->n;
1403     n = B->cmap->n;
1404     M = A->rmap->N;
1405     N = B->cmap->N;
1406     break;
1407   case MATPRODUCT_AtB:
1408     m = A->cmap->n;
1409     n = B->cmap->n;
1410     M = A->cmap->N;
1411     N = B->cmap->N;
1412     break;
1413   case MATPRODUCT_PtAP:
1414     m = B->cmap->n;
1415     n = B->cmap->n;
1416     M = B->cmap->N;
1417     N = B->cmap->N;
1418     break; /* BtAB */
1419   default:
1420     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1421   }
1422 
1423   PetscCall(MatSetSizes(C, m, n, M, N));
1424   PetscCall(PetscLayoutSetUp(C->rmap));
1425   PetscCall(PetscLayoutSetUp(C->cmap));
1426   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1427 
1428   pdata           = new MatProductData_MPIAIJKokkos();
1429   pdata->reusesym = product->api_user;
1430 
1431   if (ptype == MATPRODUCT_AB) {
1432     auto mmAB = new MatMatStruct_AB();
1433     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1434     mm = pdata->mmAB = mmAB;
1435   } else if (ptype == MATPRODUCT_AtB) {
1436     auto mmAtB = new MatMatStruct_AtB();
1437     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1438     mm = pdata->mmAtB = mmAtB;
1439   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1440     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1441 
1442     auto mmAB = new MatMatStruct_AB();
1443     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1444     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1445     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1446     pdata->mmAB = mmAB;
1447 
1448     m = A->rmap->n; // Z's layout
1449     n = B->cmap->n;
1450     M = A->rmap->N;
1451     N = B->cmap->N;
1452     PetscCall(MatCreate(comm, &Z));
1453     PetscCall(MatSetSizes(Z, m, n, M, N));
1454     PetscCall(PetscLayoutSetUp(Z->rmap));
1455     PetscCall(PetscLayoutSetUp(Z->cmap));
1456     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1457     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1458 
1459     auto mmAtB = new MatMatStruct_AtB();
1460     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1461 
1462     pdata->Z = Z; // give ownership to pdata
1463     mm = pdata->mmAtB = mmAtB;
1464   }
1465 
1466   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1467   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1468   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1469 
1470   C->product->data       = pdata;
1471   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1472   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1473   PetscFunctionReturn(PETSC_SUCCESS);
1474 }
1475 
1476 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1477 {
1478   Mat_Product *product = mat->product;
1479   PetscBool    match   = PETSC_FALSE;
1480   PetscBool    usecpu  = PETSC_FALSE;
1481 
1482   PetscFunctionBegin;
1483   MatCheckProduct(mat, 1);
1484   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1485   if (match) { /* we can always fallback to the CPU if requested */
1486     switch (product->type) {
1487     case MATPRODUCT_AB:
1488       if (product->api_user) {
1489         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1490         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1491         PetscOptionsEnd();
1492       } else {
1493         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1494         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1495         PetscOptionsEnd();
1496       }
1497       break;
1498     case MATPRODUCT_AtB:
1499       if (product->api_user) {
1500         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1501         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1502         PetscOptionsEnd();
1503       } else {
1504         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1505         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1506         PetscOptionsEnd();
1507       }
1508       break;
1509     case MATPRODUCT_PtAP:
1510       if (product->api_user) {
1511         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1512         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1513         PetscOptionsEnd();
1514       } else {
1515         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1516         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1517         PetscOptionsEnd();
1518       }
1519       break;
1520     default:
1521       break;
1522     }
1523     match = (PetscBool)!usecpu;
1524   }
1525   if (match) {
1526     switch (product->type) {
1527     case MATPRODUCT_AB:
1528     case MATPRODUCT_AtB:
1529     case MATPRODUCT_PtAP:
1530       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1531       break;
1532     default:
1533       break;
1534     }
1535   }
1536   /* fallback to MPIAIJ ops */
1537   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1538   PetscFunctionReturn(PETSC_SUCCESS);
1539 }
1540 
1541 // Mirror of MatCOOStruct_MPIAIJ on device
1542 struct MatCOOStruct_MPIAIJKokkos {
1543   PetscCount           n;
1544   PetscSF              sf;
1545   PetscCount           Annz, Bnnz;
1546   PetscCount           Annz2, Bnnz2;
1547   PetscCountKokkosView Ajmap1, Aperm1;
1548   PetscCountKokkosView Bjmap1, Bperm1;
1549   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1550   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1551   PetscCountKokkosView Cperm1;
1552   MatScalarKokkosView  sendbuf, recvbuf;
1553 
1554   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
1555     n(coo_h->n),
1556     sf(coo_h->sf),
1557     Annz(coo_h->Annz),
1558     Bnnz(coo_h->Bnnz),
1559     Annz2(coo_h->Annz2),
1560     Bnnz2(coo_h->Bnnz2),
1561     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
1562     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
1563     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
1564     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
1565     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
1566     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
1567     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
1568     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
1569     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
1570     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
1571     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
1572     sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
1573     recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
1574   {
1575     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1576   }
1577 
1578   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1579 };
1580 
1581 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1582 {
1583   PetscFunctionBegin;
1584   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1585   PetscFunctionReturn(PETSC_SUCCESS);
1586 }
1587 
1588 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1589 {
1590   PetscContainer             container_h, container_d;
1591   MatCOOStruct_MPIAIJ       *coo_h;
1592   MatCOOStruct_MPIAIJKokkos *coo_d;
1593 
1594   PetscFunctionBegin;
1595   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1596   mat->preallocated = PETSC_TRUE;
1597   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1598   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1599   PetscCall(MatZeroEntries(mat));
1600 
1601   // Copy the COO struct to device
1602   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1603   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1604   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1605 
1606   // Put the COO struct in a container and then attach that to the matrix
1607   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1608   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1609   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1610   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1611   PetscCall(PetscContainerDestroy(&container_d));
1612   PetscFunctionReturn(PETSC_SUCCESS);
1613 }
1614 
1615 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1616 {
1617   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1618   Mat                        A = mpiaij->A, B = mpiaij->B;
1619   MatScalarKokkosView        Aa, Ba;
1620   MatScalarKokkosView        v1;
1621   PetscMemType               memtype;
1622   PetscContainer             container;
1623   MatCOOStruct_MPIAIJKokkos *coo;
1624 
1625   PetscFunctionBegin;
1626   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1627   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1628 
1629   const auto &n      = coo->n;
1630   const auto &Annz   = coo->Annz;
1631   const auto &Annz2  = coo->Annz2;
1632   const auto &Bnnz   = coo->Bnnz;
1633   const auto &Bnnz2  = coo->Bnnz2;
1634   const auto &vsend  = coo->sendbuf;
1635   const auto &v2     = coo->recvbuf;
1636   const auto &Ajmap1 = coo->Ajmap1;
1637   const auto &Ajmap2 = coo->Ajmap2;
1638   const auto &Aimap2 = coo->Aimap2;
1639   const auto &Bjmap1 = coo->Bjmap1;
1640   const auto &Bjmap2 = coo->Bjmap2;
1641   const auto &Bimap2 = coo->Bimap2;
1642   const auto &Aperm1 = coo->Aperm1;
1643   const auto &Aperm2 = coo->Aperm2;
1644   const auto &Bperm1 = coo->Bperm1;
1645   const auto &Bperm2 = coo->Bperm2;
1646   const auto &Cperm1 = coo->Cperm1;
1647 
1648   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1649   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1650     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
1651   } else {
1652     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1653   }
1654 
1655   if (imode == INSERT_VALUES) {
1656     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1657     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1658   } else {
1659     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1660     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1661   }
1662 
1663   PetscCall(PetscLogGpuTimeBegin());
1664   /* Pack entries to be sent to remote */
1665   Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1666 
1667   /* Send remote entries to their owner and overlap the communication with local computation */
1668   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1669   /* Add local entries to A and B in one kernel */
1670   Kokkos::parallel_for(
1671     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1672       PetscScalar sum = 0.0;
1673       if (i < Annz) {
1674         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1675         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1676       } else {
1677         i -= Annz;
1678         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1679         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1680       }
1681     });
1682   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1683 
1684   /* Add received remote entries to A and B in one kernel */
1685   Kokkos::parallel_for(
1686     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1687       if (i < Annz2) {
1688         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1689       } else {
1690         i -= Annz2;
1691         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1692       }
1693     });
1694   PetscCall(PetscLogGpuTimeEnd());
1695 
1696   if (imode == INSERT_VALUES) {
1697     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1698     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1699   } else {
1700     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1701     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1702   }
1703   PetscFunctionReturn(PETSC_SUCCESS);
1704 }
1705 
1706 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1707 {
1708   PetscFunctionBegin;
1709   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1710   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1711   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1712   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1713   PetscCall(MatDestroy_MPIAIJ(A));
1714   PetscFunctionReturn(PETSC_SUCCESS);
1715 }
1716 
1717 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1718 {
1719   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1720   PetscBool   congruent;
1721 
1722   PetscFunctionBegin;
1723   PetscCall(MatHasCongruentLayouts(A, &congruent));
1724   if (congruent) { // square matrix and the diagonals are solely in the diag block
1725     PetscCall(MatShift(mpiaij->A, a));
1726   } else { // too hard, use the general version
1727     PetscCall(MatShift_Basic(A, a));
1728   }
1729   PetscFunctionReturn(PETSC_SUCCESS);
1730 }
1731 
1732 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1733 {
1734   PetscFunctionBegin;
1735   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1736   B->ops->mult                  = MatMult_MPIAIJKokkos;
1737   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1738   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1739   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1740   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1741   B->ops->shift                 = MatShift_MPIAIJKokkos;
1742 
1743   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1744   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1745   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1746   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1747   PetscFunctionReturn(PETSC_SUCCESS);
1748 }
1749 
1750 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1751 {
1752   Mat         B;
1753   Mat_MPIAIJ *a;
1754 
1755   PetscFunctionBegin;
1756   if (reuse == MAT_INITIAL_MATRIX) {
1757     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1758   } else if (reuse == MAT_REUSE_MATRIX) {
1759     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1760   }
1761   B = *newmat;
1762 
1763   B->boundtocpu = PETSC_FALSE;
1764   PetscCall(PetscFree(B->defaultvectype));
1765   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1766   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1767 
1768   a = static_cast<Mat_MPIAIJ *>(A->data);
1769   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1770   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1771   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1772   PetscCall(MatSetOps_MPIAIJKokkos(B));
1773   PetscFunctionReturn(PETSC_SUCCESS);
1774 }
1775 
1776 /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1778 
1779    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1780 
1781    Options Database Key:
1782 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1783 
1784   Level: beginner
1785 
1786 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1787 M*/
1788 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1789 {
1790   PetscFunctionBegin;
1791   PetscCall(PetscKokkosInitializeCheck());
1792   PetscCall(MatCreate_MPIAIJ(A));
1793   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1794   PetscFunctionReturn(PETSC_SUCCESS);
1795 }
1796 
1797 /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
1801 
1802   Collective
1803 
1804   Input Parameters:
1805 + comm  - MPI communicator, set to `PETSC_COMM_SELF`
1806 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1807            This value should be the same as the local size used in creating the
1808            y vector for the matrix-vector product y = Ax.
1809 . n     - This value should be the same as the local size used in creating the
1810        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1811        calculated if N is given) For square matrices n is almost always `m`.
1812 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1813 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1814 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1815            (same value is used for all local rows)
1816 . d_nnz - array containing the number of nonzeros in the various rows of the
1817            DIAGONAL portion of the local submatrix (possibly different for each row)
1818            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1819            The size of this array is equal to the number of local rows, i.e `m`.
1820            For matrices you plan to factor you must leave room for the diagonal entry and
1821            put in the entry even if it is zero.
1822 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1823            submatrix (same value is used for all local rows).
1824 - o_nnz - array containing the number of nonzeros in the various rows of the
1825            OFF-DIAGONAL portion of the local submatrix (possibly different for
1826            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1827            structure. The size of this array is equal to the number
1828            of local rows, i.e `m`.
1829 
1830   Output Parameter:
1831 . A - the matrix
1832 
1833   Level: intermediate
1834 
1835   Notes:
1836   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1837   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1838   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1839 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
  storage.  That is, the stored row and column indices can begin at
  either one (as in Fortran) or zero.
1843 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1846 @*/
1847 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1848 {
1849   PetscMPIInt size;
1850 
1851   PetscFunctionBegin;
1852   PetscCall(MatCreate(comm, A));
1853   PetscCall(MatSetSizes(*A, m, n, M, N));
1854   PetscCallMPI(MPI_Comm_size(comm, &size));
1855   if (size > 1) {
1856     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1857     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1858   } else {
1859     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1860     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1861   }
1862   PetscFunctionReturn(PETSC_SUCCESS);
1863 }
1864