xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 7f296bb328fcd4c99f2da7bfe8ba7ed8a4ebceee)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscmat_kokkos.hpp>
4 #include <petscpkg_version.h>
5 #include <petsc/private/sfimpl.h>
6 #include <petsc/private/kokkosimpl.hpp>
7 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
8 #include <../src/mat/impls/aij/mpi/mpiaij.h>
9 #include <KokkosSparse_spadd.hpp>
10 #include <KokkosSparse_spgemm.hpp>
11 
12 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
13 {
14   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
15 
16   PetscFunctionBegin;
17   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
18   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
19      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
20    */
21   if (mode == MAT_FINAL_ASSEMBLY) {
22     PetscScalarKokkosView v;
23 
24     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
25     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
26     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
27     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
28     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
29   }
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
33 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
34 {
35   Mat_MPIAIJ *mpiaij;
36 
37   PetscFunctionBegin;
38   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
39   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
40   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
41   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
42   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
43   PetscFunctionReturn(PETSC_SUCCESS);
44 }
45 
46 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
47 {
48   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
49   PetscInt    nt;
50 
51   PetscFunctionBegin;
52   PetscCall(VecGetLocalSize(xx, &nt));
53   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
54   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
55   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
56   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
57   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
58   PetscFunctionReturn(PETSC_SUCCESS);
59 }
60 
61 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
62 {
63   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
64   PetscInt    nt;
65 
66   PetscFunctionBegin;
67   PetscCall(VecGetLocalSize(xx, &nt));
68   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
69   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
70   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
71   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
72   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
73   PetscFunctionReturn(PETSC_SUCCESS);
74 }
75 
76 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
77 {
78   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
79   PetscInt    nt;
80 
81   PetscFunctionBegin;
82   PetscCall(VecGetLocalSize(xx, &nt));
83   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
84   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
85   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
86   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
87   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
88   PetscFunctionReturn(PETSC_SUCCESS);
89 }
90 
91 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
92    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
93    C still uses local column ids. Their corresponding global column ids are returned in glob.
94 */
95 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
96 {
97   Mat             Ad, Ao;
98   const PetscInt *cmap;
99 
100   PetscFunctionBegin;
101   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
102   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
103   if (glob) {
104     PetscInt cst, i, dn, on, *gidx;
105     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
106     PetscCall(MatGetLocalSize(Ao, NULL, &on));
107     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
108     PetscCall(PetscMalloc1(dn + on, &gidx));
109     for (i = 0; i < dn; i++) gidx[i] = cst + i;
110     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
111     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
112   }
113   PetscFunctionReturn(PETSC_SUCCESS);
114 }
115 
116 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
117 struct MatMatStruct {
118   PetscInt            n, *garray;     // C's garray and its size.
119   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
120   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
121   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
122   PetscIntKokkosView  E_NzLeft;
123   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
124   MatScalarKokkosView rootBuf, leafBuf;
125   KokkosCsrMatrix     Fd, Fo; // F in split form
126 
127   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
128   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
129   KernelHandle kh3; // compute C3
130   KernelHandle kh4; // compute C4
131 
132   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
133   PetscInt E_VectorLength;
134   PetscInt E_RowsPerTeam;
135   PetscInt F_TeamSize;
136   PetscInt F_VectorLength;
137   PetscInt F_RowsPerTeam;
138 
139   ~MatMatStruct()
140   {
141     PetscFunctionBegin;
142     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
143     PetscFunctionReturnVoid();
144   }
145 };
146 
147 struct MatMatStruct_AB : public MatMatStruct {
148   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
149   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
150   PetscIntKokkosView rowoffset;
151 };
152 
153 struct MatMatStruct_AtB : public MatMatStruct {
154   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
155   MatColIdxKokkosView Fdjperm;
156   MatColIdxKokkosView Fojmap;
157   MatColIdxKokkosView Fojperm;
158 };
159 
160 struct MatProductData_MPIAIJKokkos {
161   MatMatStruct_AB  *mmAB     = nullptr;
162   MatMatStruct_AtB *mmAtB    = nullptr;
163   PetscBool         reusesym = PETSC_FALSE;
164   Mat               Z        = nullptr; // store Z=AB in computing BtAB
165 
166   ~MatProductData_MPIAIJKokkos()
167   {
168     delete mmAB;
169     delete mmAtB;
170     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
171   }
172 };
173 
174 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
175 {
176   PetscFunctionBegin;
177   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
178   PetscFunctionReturn(PETSC_SUCCESS);
179 }
180 
181 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
182 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
183 template <class ExecutionSpace>
184 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
185 {
186 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
187   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
188 #else
189   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
190 #endif
191   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
192 
193   PetscFunctionBegin;
194   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
195 
196   if (nnz_per_row < 1) nnz_per_row = 1;
197 
198   int max_vector_length = teamPolicy.vector_length_max();
199 
200   if (vector_length < 1) {
201     vector_length = 1;
202     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
203   }
204 
205   // Determine rows per thread
206   if (rows_per_thread < 1) {
207     if (is_gpu_exec_space) rows_per_thread = 1;
208     else {
209       if (nnz_per_row < 20 && nnz > 5000000) {
210         rows_per_thread = 256;
211       } else rows_per_thread = 64;
212     }
213   }
214 
215   if (team_size < 1) {
216     if (is_gpu_exec_space) {
217       team_size = 256 / vector_length;
218     } else {
219       team_size = 1;
220     }
221   }
222 
223   rows_per_team = rows_per_thread * team_size;
224 
225   if (rows_per_team < 0) {
226     PetscInt nnz_per_team = 4096;
227     PetscInt conc         = ExecutionSpace().concurrency();
228     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
229     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
230   }
231   PetscFunctionReturn(PETSC_SUCCESS);
232 }
233 
234 /*
235   Reduce two sets of global indices into local ones
236 
237   Input Parameters:
238 +  n1          - size of garray1[], the first set
239 .  garray1[n1] - a sorted global index array (without duplicates)
240 .  m           - size of indices[], the second set
241 -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
242 
243   Output Parameters:
244 +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
245 .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
246 .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
247 -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
248 
249    Example, say
250     n1         = 5
251     garray1[5] = {1, 4, 7, 8, 10}
252     m          = 4
253     indices[4] = {2, 4, 8, 9}
254 
255    Combining them together, we have 7 global indices in garray2[]
256     n2         = 7
257     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
258 
259    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
260     map[5] = {0, 2, 3, 4, 6}
261 
262    On output, indices[] is updated with local indices
263     indices[4] = {1, 2, 4, 5}
264 */
265 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
266 {
267   PetscHMapI    g2l = nullptr;
268   PetscHashIter iter;
269   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
270   PetscInt      n2, *garray2;
271 
272   PetscFunctionBegin;
273   tot = 0;
274   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
275   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
276     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
277     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
278   }
279 
280   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
281     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
282     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
283   }
284 
285   // Pull out (unique) globals in the hash table and put them in garray2[]
286   n2 = tot;
287   PetscCall(PetscMalloc1(n2, &garray2));
288   tot = 0;
289   PetscHashIterBegin(g2l, iter);
290   while (!PetscHashIterAtEnd(g2l, iter)) {
291     PetscHashIterGetKey(g2l, iter, key);
292     PetscHashIterNext(g2l, iter);
293     garray2[tot++] = key;
294   }
295 
296   // Sort garray2[] and then map them to local indices starting from 0
297   PetscCall(PetscSortInt(n2, garray2));
298   PetscCall(PetscHMapIClear(g2l));
299   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
300 
301   // Rewrite indices[] with local indices
302   for (PetscInt i = 0; i < m; i++) {
303     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
304     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
305     indices[i] = val;
306   }
307   // Record the map that maps garray1[i] to garray2[map[i]]
308   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
309   PetscCall(PetscHMapIDestroy(&g2l));
310   *n2_      = n2;
311   *garray2_ = garray2;
312   PetscFunctionReturn(PETSC_SUCCESS);
313 }
314 
315 /*
316   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
317 
318   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
319 
320   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
321   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
322 
323   Input Parameters:
324 +  comm       - MPI communicator of E
325 .  A          - diag block of E, using local column indices
326 .  B          - off-diag block of E, using local column indices
327 .  cstart      - (global) start column of Ed
328 .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
329 .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
330 .  ownerSF     - the SF specifies ownership (root) of rows in E
331 .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
332 -  mm          - to stash intermediate data structures for reuse
333 
334   Output Parameters:
335 +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
336 -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
337 
338   Notes:
339   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
340 
341  */
342 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
343 {
344   PetscFunctionBegin;
345   if (reuse == MAT_INITIAL_MATRIX) {
346     PetscInt Em = A.numRows(), Fm;
347     PetscInt n1 = B.numCols();
348 
349     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
350 
351     // Do the analysis on host
352     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map);
353     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries);
354     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map);
355     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries);
356     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
357     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
358 
359     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
360     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
361     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
362     for (PetscInt i = 0; i < Em; i++) {
363       const PetscInt *first, *last, *it;
364       PetscInt        count, step;
365       // std::lower_bound(first,last,cstart), but need to use global column indices
366       first = Bj + Bi[i];
367       last  = Bj + Bi[i + 1];
368       count = last - first;
369       while (count > 0) {
370         it   = first;
371         step = count / 2;
372         it += step;
373         if (garray1[*it] < cstart) { // map local to global
374           first = ++it;
375           count -= step + 1;
376         } else count = step;
377       }
378       E_NzLeft[i] = first - (Bj + Bi[i]);
379       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
380     }
381 
382     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
383     const PetscMPIInt *iranks, *ranks;
384     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
385     PetscMPIInt        niranks, nranks;
386     MPI_Request       *reqs;
387     PetscMPIInt        tag;
388     PetscSF            reduceSF;
389     PetscInt          *sdisp, *rdisp;
390 
391     PetscCall(PetscCommGetNewTag(comm, &tag));
392     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
393     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
394 
395     // Find out length of each row I will receive. Even for the same row index, when they are from
396     // different senders, they might have different lengths (and sparsity patterns)
397     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
398     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
399 
400     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
401 
402     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
403     recvRowLen[0] = 0; // since we will make it in CSR format later
404     recvRowLen++;      // advance the pointer now
405     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
406     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
407     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
408 
409     // Build the real PetscSF for reducing E rows (buffer to buffer)
410     rdisp[0] = 0;
411     for (PetscInt i = 0; i < niranks; i++) {
412       rdisp[i + 1] = rdisp[i];
413       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
414     }
415     recvRowLen--; // put it back into csr format
416     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
417 
418     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
419     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
420     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
421 
422     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
423     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
424     PetscSFNode *iremote;
425 
426     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
427     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
428     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
429 
430     for (PetscInt i = 0; i < nranks; i++) {
431       PetscInt count = 0;
432       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
433       for (PetscInt j = 0; j < count; j++) {
434         iremote[nleaves + j].rank  = ranks[i];
435         iremote[nleaves + j].index = sdisp[i] + j;
436       }
437       nleaves += count;
438     }
439     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
440 
441     PetscCall(PetscSFCreate(comm, &reduceSF));
442     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
443 
444     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
445     PetscInt *sendCol, *recvCol;
446     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
447     for (PetscInt k = 0; k < roffset[nranks]; k++) {
448       PetscInt  i      = rmine[k]; // row to be copied
449       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
450       PetscInt  nzLeft = E_NzLeft[i];
451       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
452       for (PetscInt j = 0; j < alen + blen; j++) {
453         if (j < nzLeft) {
454           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
455         } else if (j < nzLeft + alen) {
456           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
457         } else {
458           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
459         }
460       }
461     }
462     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
463     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
464 
465     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
466     PetscInt *recvRowPerm, *recvColSorted;
467     PetscInt *recvNzPerm, *recvNzPermSorted;
468     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
469 
470     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
471     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
472     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
473 
474     // i[] array, nz are always easiest to compute
475     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
476     MatRowMapType          *Fdi, *Foi;
477     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
478     PetscInt                iter;
479 
480     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
481     Kokkos::deep_copy(Foi_h, 0);
482     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
483     Foi  = Foi_h.data() + 1;
484     iter = 0;
485     while (iter < recvRowCnt) { // iter over received rows
486       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
487       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
488 
489       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
490 
491       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
492       PetscInt  nz    = 0; // nz (with dups) in the current row
493       PetscInt *jbuf  = recvColSorted + FnzDups;
494       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
495       PetscInt *jbuf2 = jbuf; // temp pointers
496       PetscInt *pbuf2 = pbuf;
497       for (PetscInt d = 0; d < dupRows; d++) {
498         PetscInt i   = recvRowPerm[iter + d];
499         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
500         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
501         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
502         jbuf2 += len;
503         pbuf2 += len;
504         nz += len;
505       }
506       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
507 
508       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
509       PetscInt cur = 0;
510       while (cur < nz) {
511         PetscInt curColIdx = jbuf[cur];
512         PetscInt dups      = 1;
513 
514         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
515         if (curColIdx >= cstart && curColIdx < cend) {
516           Fdi[curRowIdx]++;
517           FdnzDups += dups;
518         } else {
519           Foi[curRowIdx]++;
520           FonzDups += dups;
521         }
522         cur += dups;
523       }
524 
525       FnzDups += nz;
526       iter += dupRows; // Move to next unique row
527     }
528 
529     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
530     Foi = Foi_h.data();
531     for (PetscInt i = 0; i < Fm; i++) {
532       Fdi[i + 1] += Fdi[i];
533       Foi[i + 1] += Foi[i];
534     }
535     Fdnz = Fdi[Fm];
536     Fonz = Foi[Fm];
537     PetscCall(PetscFree2(sendCol, recvCol));
538 
539     // Allocate j, jmap, jperm for Fd and Fo
540     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
541     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
542     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
543     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
544     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
545     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
546 
547     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
548     Fdjmap[0] = 0;
549     Fojmap[0] = 0;
550     FnzDups   = 0;
551     Fdnz      = 0;
552     Fonz      = 0;
553     iter      = 0; // iter over received rows
554     while (iter < recvRowCnt) {
555       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
556       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
557       PetscInt nz        = 0;                           // nz (with dups) in the current row
558 
559       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
560       for (PetscInt d = 0; d < dupRows; d++) {
561         PetscInt i = recvRowPerm[iter + d];
562         nz += recvRowLen[i + 1] - recvRowLen[i];
563       }
564 
565       PetscInt *jbuf = recvColSorted + FnzDups;
566       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
567       PetscInt cur = 0;
568       while (cur < nz) {
569         PetscInt curColIdx = jbuf[cur];
570         PetscInt dups      = 1;
571 
572         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
573         if (curColIdx >= cstart && curColIdx < cend) {
574           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
575           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
576           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
577           FdnzDups += dups;
578           Fdnz++;
579         } else {
580           Foj[Fonz]        = curColIdx; // in global
581           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
582           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
583           FonzDups += dups;
584           Fonz++;
585         }
586         cur += dups;
587         FnzDups += dups;
588       }
589       iter += dupRows; // Move to next unique row
590     }
591     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
592     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
593 
594     // Combine global column indices in garray1[] and Foj[]
595     PetscInt n2, *garray2;
596 
597     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
598     mm->sf       = reduceSF;
599     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
600     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
601     mm->garray   = garray2; // give ownership, so no free
602     mm->n        = n2;
603     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
604     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
605     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
606     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
607     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
608 
609     // Output Fd and Fo in KokkosCsrMatrix format
610     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
611     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
612     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
613     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
614     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
615     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
616 
617     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
618     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
619 
620     // Compute kernel launch parameters in merging E
621     PetscInt teamSize, vectorLength, rowsPerTeam;
622 
623     teamSize = vectorLength = rowsPerTeam = -1;
624     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
625     mm->E_TeamSize     = teamSize;
626     mm->E_VectorLength = vectorLength;
627     mm->E_RowsPerTeam  = rowsPerTeam;
628   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
629 
630   // Handy aliases
631   auto       &Aa           = A.values;
632   auto       &Ba           = B.values;
633   const auto &Ai           = A.graph.row_map;
634   const auto &Bi           = B.graph.row_map;
635   const auto &E_NzLeft     = mm->E_NzLeft;
636   auto       &leafBuf      = mm->leafBuf;
637   auto       &rootBuf      = mm->rootBuf;
638   PetscSF     reduceSF     = mm->sf;
639   PetscInt    Em           = A.numRows();
640   PetscInt    teamSize     = mm->E_TeamSize;
641   PetscInt    vectorLength = mm->E_VectorLength;
642   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
643   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
644 
645   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
646   PetscCallCXX(Kokkos::parallel_for(
647     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
648       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
649         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
650         if (i < Em) {
651           PetscInt disp   = Ai(i) + Bi(i);
652           PetscInt alen   = Ai(i + 1) - Ai(i);
653           PetscInt blen   = Bi(i + 1) - Bi(i);
654           PetscInt nzleft = E_NzLeft(i);
655 
656           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
657             MatScalar &val = leafBuf(disp + j);
658             if (j < nzleft) { // B left
659               val = Ba(Bi(i) + j);
660             } else if (j < nzleft + alen) { // diag A
661               val = Aa(Ai(i) + j - nzleft);
662             } else { // B right
663               val = Ba(Bi(i) + j - alen);
664             }
665           });
666         }
667       });
668     }));
669   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
670   PetscFunctionReturn(PETSC_SUCCESS);
671 }
672 
673 // To finish MatMPIAIJKokkosReduce.
674 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
675 {
676   auto       &leafBuf  = mm->leafBuf;
677   auto       &rootBuf  = mm->rootBuf;
678   auto       &Fda      = mm->Fd.values;
679   const auto &Fdjmap   = mm->Fdjmap;
680   const auto &Fdjperm  = mm->Fdjperm;
681   auto        Fdnz     = mm->Fd.nnz();
682   auto       &Foa      = mm->Fo.values;
683   const auto &Fojmap   = mm->Fojmap;
684   const auto &Fojperm  = mm->Fojperm;
685   auto        Fonz     = mm->Fo.nnz();
686   PetscSF     reduceSF = mm->sf;
687 
688   PetscFunctionBegin;
689   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
690 
691   // Reduce data in rootBuf to Fd and Fo
692   PetscCallCXX(Kokkos::parallel_for(
693     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
694       PetscScalar sum = 0.0;
695       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
696       Fda(i) = sum;
697     }));
698 
699   PetscCallCXX(Kokkos::parallel_for(
700     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
701       PetscScalar sum = 0.0;
702       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
703       Foa(i) = sum;
704     }));
705   PetscFunctionReturn(PETSC_SUCCESS);
706 }
707 
708 /*
  MatMPIAIJKokkosBcast - Bcast local rows of an MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
710 
711   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
712   device and involves various index mapping.
713 
714   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
  If F's j-th row is connected to a root identified by PetscSFNode (k,i), then we need to bcast the i-th row of E on rank k
716   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
717   F has the same column layout as E.
718 
  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
720   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
721   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
722   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
723   column indices in Fo and update Fo with local indices.
724 
725    Input Parameters:
726 +   E       - the MPIAIJKOKKOS matrix
727 .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
728 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
729 -   mm      - to stash matproduct intermediate data structures
730 
731     Output Parameters:
732 +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
733 -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
734 
735     Notes:
736     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
    The routine is provided in split-phase form, MatMPIAIJKokkosBcastBegin/End(), to allow overlapping computation with communication.
738 */
739 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
740 {
741   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
742   Mat               A = empi->A, B = empi->B; // diag and off-diag
743   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
744   PetscInt          Em = E->rmap->n; // #local rows
745   MPI_Comm          comm;
746 
747   PetscFunctionBegin;
748   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
749   if (reuse == MAT_INITIAL_MATRIX) {
750     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
751     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
752     const PetscInt *garray1 = empi->garray; // its size is n1
753     PetscInt        cstart, cend;
754     PetscSF         bcastSF;
755 
756     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
757 
758     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
759     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
760     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
761     for (PetscInt i = 0; i < Em; i++) {
762       const PetscInt *first, *last, *it;
763       PetscInt        count, step;
764       // std::lower_bound(first,last,cstart), but need to use global column indices
765       first = Bj + Bi[i];
766       last  = Bj + Bi[i + 1];
767       count = last - first;
768       while (count > 0) {
769         it   = first;
770         step = count / 2;
771         it += step;
772         if (empi->garray[*it] < cstart) { // map local to global
773           first = ++it;
774           count -= step + 1;
775         } else count = step;
776       }
777       E_NzLeft[i] = first - (Bj + Bi[i]);
778       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
779     }
780 
781     // Compute row pointer Fi of F
782     PetscInt *Fi, Fm, Fnz;
783     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
784     PetscCall(PetscMalloc1(Fm + 1, &Fi));
785     Fi[0] = 0;
786     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
787     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
788     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
789     Fnz = Fi[Fm];
790 
791     // Build the real PetscSF for bcasting E rows (buffer to buffer)
792     const PetscMPIInt *iranks, *ranks;
793     const PetscInt    *ioffset, *irootloc, *roffset;
794     PetscMPIInt        niranks, nranks;
795     PetscInt          *sdisp, *rdisp;
796     MPI_Request       *reqs;
797     PetscMPIInt        tag;
798 
799     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
800     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
801     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
802 
803     sdisp[0] = 0; // send displacement
804     for (PetscInt i = 0; i < niranks; i++) {
805       sdisp[i + 1] = sdisp[i];
806       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
807         PetscInt r = irootloc[j]; // row to be sent
808         sdisp[i + 1] += E_RowLen[r];
809       }
810     }
811 
812     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
813     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
814     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
815     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
816 
817     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
818     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
819     PetscSFNode *iremote;                  // give ownership to bcastSF
820     PetscCall(PetscMalloc1(nleaves, &iremote));
821     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
822       PetscInt k = 0;
823       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
824         iremote[j].rank  = ranks[i];
825         iremote[j].index = rdisp[i] + k; // their root location
826         k++;
827       }
828     }
829     PetscCall(PetscSFCreate(comm, &bcastSF));
830     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
831     PetscCall(PetscFree3(sdisp, rdisp, reqs));
832 
833     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
834     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
835     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
836     rowoffset[0]                     = 0;
837     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
838 
839     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
840     PetscInt *jbuf, *Fj;
841     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
842     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
843       PetscInt  i      = irootloc[k]; // row to be copied
844       PetscInt *buf    = &jbuf[rowoffset[k]];
845       PetscInt  nzLeft = E_NzLeft[i];
846       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
847       for (PetscInt j = 0; j < alen + blen; j++) {
848         if (j < nzLeft) {
849           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
850         } else if (j < nzLeft + alen) {
851           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
852         } else {
853           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
854         }
855       }
856     }
857     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
858     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
859 
860     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
861     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
862     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
863     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
864     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
865 
866     Fdi[0] = Foi[0] = 0;
867     for (PetscInt i = 0; i < Fm; i++) {
868       PetscInt *first, *last, *lb1, *lb2;
869       // cut the row into: Left, [cstart, cend), Right
870       first       = Fj + Fi[i];
871       last        = Fj + Fi[i + 1];
872       lb1         = std::lower_bound(first, last, cstart);
873       F_NzLeft[i] = lb1 - first;
874       lb2         = std::lower_bound(first, last, cend);
875       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
876       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
877     }
878     for (PetscInt i = 0; i < Fm; i++) {
879       Fdi[i + 1] += Fdi[i];
880       Foi[i + 1] += Foi[i];
881     }
882 
883     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
884     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
885     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
886     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
887 
888     for (PetscInt i = 0; i < Fm; i++) {
889       PetscInt nzLeft = F_NzLeft[i];
890       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
891       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
892         gid = Fj[Fi[i] + j];
893         if (j < nzLeft) { // left, in global
894           Foj[Foi[i] + j] = gid;
895         } else if (j < nzLeft + len) { // diag, in local
896           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
897         } else { // right, in global
898           Foj[Foi[i] + j - len] = gid;
899         }
900       }
901     }
902     PetscCall(PetscFree2(jbuf, Fj));
903     PetscCall(PetscFree(Fi));
904 
905     // Reduce global indices in Foj[] and garray1[] into local ones
906     PetscInt n2, *garray2;
907     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
908 
909     // Record the plans built above, for reuse
910     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
911     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
912     Kokkos::deep_copy(irootloc_h, tmp);
913     mm->sf        = bcastSF;
914     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
915     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
916     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
917     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
918     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
919     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
920     mm->garray    = garray2;
921     mm->n         = n2;
922 
923     // Output Fd and Fo in KokkosCsrMatrix format
924     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
925     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
926     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
927     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
928     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
929 
930     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
931     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
932 
933     // Compute kernel launch parameters in merging E or splitting F
934     PetscInt teamSize, vectorLength, rowsPerTeam;
935 
936     teamSize = vectorLength = rowsPerTeam = -1;
937     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
938     mm->E_TeamSize     = teamSize;
939     mm->E_VectorLength = vectorLength;
940     mm->E_RowsPerTeam  = rowsPerTeam;
941 
942     teamSize = vectorLength = rowsPerTeam = -1;
943     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
944     mm->F_TeamSize     = teamSize;
945     mm->F_VectorLength = vectorLength;
946     mm->F_RowsPerTeam  = rowsPerTeam;
947   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
948 
949   // Sync E's value to device
950   akok->a_dual.sync_device();
951   bkok->a_dual.sync_device();
952 
953   // Handy aliases
954   const auto &Aa = akok->a_dual.view_device();
955   const auto &Ba = bkok->a_dual.view_device();
956   const auto &Ai = akok->i_dual.view_device();
957   const auto &Bi = bkok->i_dual.view_device();
958 
959   // Fetch the plans
960   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
961   PetscSF             &bcastSF   = mm->sf;
962   MatScalarKokkosView &rootBuf   = mm->rootBuf;
963   MatScalarKokkosView &leafBuf   = mm->leafBuf;
964   PetscIntKokkosView  &irootloc  = mm->irootloc;
965   PetscIntKokkosView  &rowoffset = mm->rowoffset;
966 
967   PetscInt teamSize     = mm->E_TeamSize;
968   PetscInt vectorLength = mm->E_VectorLength;
969   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
970   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
971 
972   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
973   PetscCallCXX(Kokkos::parallel_for(
974     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
975       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
976         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
977         if (r < irootloc.extent(0)) {
978           PetscInt i      = irootloc(r); // row i of E
979           PetscInt disp   = rowoffset(r);
980           PetscInt alen   = Ai(i + 1) - Ai(i);
981           PetscInt blen   = Bi(i + 1) - Bi(i);
982           PetscInt nzleft = E_NzLeft(i);
983 
984           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
985             if (j < nzleft) { // B left
986               rootBuf(disp + j) = Ba(Bi(i) + j);
987             } else if (j < nzleft + alen) { // diag A
988               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
989             } else { // B right
990               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
991             }
992           });
993         }
994       });
995     }));
996   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
997   PetscFunctionReturn(PETSC_SUCCESS);
998 }
999 
1000 // To finish MatMPIAIJKokkosBcast.
1001 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1002 {
1003   PetscFunctionBegin;
1004   const auto &Fd  = mm->Fd;
1005   const auto &Fo  = mm->Fo;
1006   const auto &Fdi = Fd.graph.row_map;
1007   const auto &Foi = Fo.graph.row_map;
1008   auto       &Fda = Fd.values;
1009   auto       &Foa = Fo.values;
1010   auto        Fm  = Fd.numRows();
1011 
1012   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1013   PetscSF             &bcastSF      = mm->sf;
1014   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1015   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1016   PetscInt             teamSize     = mm->F_TeamSize;
1017   PetscInt             vectorLength = mm->F_VectorLength;
1018   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1019   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1020 
1021   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1022 
1023   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1024   PetscCallCXX(Kokkos::parallel_for(
1025     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1026       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1027         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1028         if (i < Fm) {
1029           PetscInt nzLeft = F_NzLeft(i);
1030           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1031           PetscInt blen   = Foi(i + 1) - Foi(i);
1032           PetscInt Fii    = Fdi(i) + Foi(i);
1033 
1034           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1035             PetscScalar val = leafBuf(Fii + j);
1036             if (j < nzLeft) { // left
1037               Foa(Foi(i) + j) = val;
1038             } else if (j < nzLeft + alen) { // diag
1039               Fda(Fdi(i) + j - nzLeft) = val;
1040             } else { // right
1041               Foa(Foi(i) + j - alen) = val;
1042             }
1043           });
1044         }
1045       });
1046     }));
1047   PetscFunctionReturn(PETSC_SUCCESS);
1048 }
1049 
/*
  MatProductSymbolic_MPIAIJKokkos_AtB - AtB flavor of MatProductSymbolic_MPIAIJKokkos: sets up C = A^T*B in split form (mm->Cd, mm->Co)

  With Adt/Aot the transposes of A's diag/off-diag blocks and Bd/Bo the diag/off-diag blocks of B, it computes
    C3 = Aot*Bd, C4 = Aot*Bo   (together called E; rows of Aot correspond to columns of A's off-diag block)
    C1 = Adt*Bd, C2 = Adt*Bo   (C2 is C2_mid with column indices remapped through map)
  reduces E's rows into F = (Fd, Fo) with an SF derived from A's Mvctx, and forms Cd = C1 + Fd, Co = C2 + Fo.
  The reduction of E is started before, and finished after, the computation of C1/C2_mid to overlap communication
  with computation. Numeric values are also computed here, since spgemm_numeric()/spadd_numeric() are called.

  Input Parameters:
+  product - Mat_Product carrying out the computation (not referenced in this routine)
.  A       - an MPIAIJKOKKOS matrix
.  B       - an MPIAIJKOKKOS matrix
-  mm      - stash for intermediate data (kernel handles, plans, C1..C4, Fd, Fo); persists from symbolic to numeric phase
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
  PetscInt        cstart, cend;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); // Adt = transpose of A's diag block
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); // Aot = transpose of A's off-diag block
  // NOTE(review): Ad and Ao are fetched but not used below; possibly kept only for a side effect of the getter -- confirm before removing
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One kernel handle per product (kh1: C1, kh2: C2_mid, kh3: C3, kh4: C4); they are reused by the numeric phase
  PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));

  // Aot * (B's diag + B's off-diag)
  // Computed first so that the reduction of its result (E) can overlap with the computation of C1/C2 below
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)

  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  // map_h has one entry per column of B's off-diag block; it is passed to the reduce (presumably filled there) and used below to remap C2_mid's columns
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
  PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
  PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Adt * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  // Because C2 reuses C2_mid's values view, later numeric updates of C2_mid are visible through C2 as well
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));

  // C = (C1+Fd, C2+Fo)
  PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
  PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1127 
/*
  MatProductNumeric_MPIAIJKokkos_AtB - Numeric phase of C = A^T*B for MPIAIJKOKKOS matrices

  Recomputes only values, reusing the kernel handles (kh1..kh4), sparsity patterns and reduction plan stored in mm by
  MatProductSymbolic_MPIAIJKokkos_AtB(). As in the symbolic phase, the reduction of E = (C3, C4) overlaps with the
  computation of C1 and C2_mid.

  Input Parameters:
+  product - Mat_Product carrying out the computation (not referenced in this routine)
.  A       - an MPIAIJKOKKOS matrix
.  B       - an MPIAIJKOKKOS matrix
-  mm      - intermediate data created by the symbolic phase
*/
static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Bd, Bo;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); // Adt = transpose of A's diag block
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); // Aot = transpose of A's off-diag block
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // Aot * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));

  // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  // With MAT_REUSE_MATRIX the plan already lives in mm, hence the zero/NULL arguments
  PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));

  // Adt * (B's diag + B's off-diag)
  // mm->C2 was built on top of C2_mid's values view, so updating C2_mid's values also updates C2
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));

  PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));

  // C = (C1+Fd, C2+Fo)
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1160 
1161 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1162 
1163   Input Parameters:
1164 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1165 .  A        - an MPIAIJKOKKOS matrix
1166 .  B        - an MPIAIJKOKKOS matrix
1167 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1168 */
1169 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1170 {
1171   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1172   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1173   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1174 
1175   PetscFunctionBegin;
1176   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1177   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1178   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1179   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1180 
1181   // TODO: add command line options to select spgemm algorithms
1182   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1183 
1184   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1185 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1186   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1187   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1188   #endif
1189 #endif
1190 
1191   mm->kh1.create_spgemm_handle(spgemm_alg);
1192   mm->kh2.create_spgemm_handle(spgemm_alg);
1193   mm->kh3.create_spgemm_handle(spgemm_alg);
1194   mm->kh4.create_spgemm_handle(spgemm_alg);
1195 
1196   // Bcast B's rows to form F, and overlap the communication
1197   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1198   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1199 
1200   // A's diag * (B's diag + B's off-diag)
1201   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1202   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1203   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1204   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1205   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1206   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1207 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1208   PetscCallCXX(sort_crs_matrix(mm->C1));
1209   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1210 #endif
1211 
1212   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1213 
1214   // A's off-diag * (F's diag + F's off-diag)
1215   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1216   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1217   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1218   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1219 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1220   PetscCallCXX(sort_crs_matrix(mm->C3));
1221   PetscCallCXX(sort_crs_matrix(mm->C4));
1222 #endif
1223 
1224   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1225   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1226   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1227   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1228   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1229 
1230   // C = (Cd, Co) = (C1+C3, C2+C4)
1231   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1232   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1233   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1234   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1235   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1236   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1237   PetscFunctionReturn(PETSC_SUCCESS);
1238 }
1239 
1240 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241 {
1242   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1243   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1244   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245 
1246   PetscFunctionBegin;
1247   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1248   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1249   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1250   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251 
1252   // Bcast B's rows to form F, and overlap the communication
1253   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1254 
1255   // A's diag * (B's diag + B's off-diag)
1256   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1257   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1258 
1259   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1260 
1261   // A's off-diag * (F's diag + F's off-diag)
1262   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1263   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1264 
1265   // C = (Cd, Co) = (C1+C3, C2+C4)
1266   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1267   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1268   PetscFunctionReturn(PETSC_SUCCESS);
1269 }
1270 
1271 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1272 {
1273   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1274   Mat_Product                 *product;
1275   MatProductData_MPIAIJKokkos *pdata;
1276   MatProductType               ptype;
1277   Mat                          A, B;
1278 
1279   PetscFunctionBegin;
1280   MatCheckProduct(C, 1); // make sure C is a product
1281   product = C->product;
1282   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1283   ptype   = product->type;
1284   A       = product->A;
1285   B       = product->B;
1286 
1287   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1288   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1289   // we still do numeric.
1290   if (pdata->reusesym) { // numeric reuses results from symbolic
1291     pdata->reusesym = PETSC_FALSE;
1292     PetscFunctionReturn(PETSC_SUCCESS);
1293   }
1294 
1295   if (ptype == MATPRODUCT_AB) {
1296     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1297   } else if (ptype == MATPRODUCT_AtB) {
1298     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1299   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1300     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1301     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1302   }
1303   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1304   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1305   PetscFunctionReturn(PETSC_SUCCESS);
1306 }
1307 
// Symbolic phase of C = A*B (AB), A^T*B (AtB) or B^T*A*B (PtAP) for MPIAIJKOKKOS matrices.
// - Sets C's local/global sizes, sets up its layouts and gives it the same type name as A.
// - Allocates a MatProductData_MPIAIJKokkos (pdata) caching the intermediate structures for later numeric phases.
// - Wraps the resulting diagonal/off-diagonal Kokkos CSR matrices (mm->Cd, mm->Co) as C's local blocks together
//   with mm->garray, sets C's block sizes, and installs MatProductNumeric_MPIAIJKokkos as C's numeric routine.
// Errors with PETSC_ERR_PLIB if C->product already carries data or the product type is not AB, AtB or PtAP.
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *pdata;
  MatMatStruct                *mm = NULL;
  PetscInt                     m, n, M, N;
  Mat                          Cd, Co;
  MPI_Comm                     comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
  MatCheckProduct(C, 1);
  product = C->product;
  PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  // Local (m, n) and global (M, N) row/column sizes of the result C
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  pdata           = new MatProductData_MPIAIJKokkos();
  // The symbolic sub-routines also compute values; when called through the user-level API, the next
  // MatProductNumeric_MPIAIJKokkos() call sees this flag and returns early instead of recomputing.
  pdata->reusesym = product->api_user;

  // After this if/else chain, mm points at the struct holding the final result's {Cd, Co} and garray
  if (ptype == MATPRODUCT_AB) {
    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
    mm = pdata->mmAB = mmAB;
  } else if (ptype == MATPRODUCT_AtB) {
    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
    mm = pdata->mmAtB = mmAtB;
  } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
    Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z

    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
    pdata->mmAB = mmAB;

    // NOTE(review): m and n are not read after this point; only M and N are passed to MatCreateMPIAIJWithSeqAIJ()
    m = A->rmap->n; // Z's layout
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    PetscCall(MatCreateMPIAIJWithSeqAIJ(comm, M, N, Zd, Zo, mmAB->garray, &Z));

    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

    pdata->Z = Z; // give ownership to pdata
    mm = pdata->mmAtB = mmAtB;
  }

  // Wrap the final Kokkos CSR blocks as SeqAIJKokkos matrices and install them as C's diagonal/off-diagonal parts
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
  PetscCall(MatSetMPIAIJWithSplitSeqAIJ(C, Cd, Co, mm->garray));
  /* set block sizes */
  // Only AB, AtB and PtAP can reach this point (the first switch rejects every other type);
  // the RARt, ABC and ABt cases are therefore not exercised here.
  switch (ptype) {
  case MATPRODUCT_PtAP:
    if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs));
    break;
  case MATPRODUCT_RARt:
    if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs));
    break;
  case MATPRODUCT_ABC:
    PetscCall(MatSetBlockSizesFromMats(C, A, product->C));
    break;
  case MATPRODUCT_AB:
    PetscCall(MatSetBlockSizesFromMats(C, A, B));
    break;
  case MATPRODUCT_AtB:
    if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs));
    break;
  case MATPRODUCT_ABt:
    if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs));
    break;
  default:
    SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]);
  }
  // Attach pdata (and the routine that frees it) to C's product, and route subsequent numeric phases here
  C->product->data       = pdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1420 
1421 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1422 {
1423   Mat_Product *product = mat->product;
1424   PetscBool    match   = PETSC_FALSE;
1425   PetscBool    usecpu  = PETSC_FALSE;
1426 
1427   PetscFunctionBegin;
1428   MatCheckProduct(mat, 1);
1429   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1430   if (match) { /* we can always fallback to the CPU if requested */
1431     switch (product->type) {
1432     case MATPRODUCT_AB:
1433       if (product->api_user) {
1434         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1435         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1436         PetscOptionsEnd();
1437       } else {
1438         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1439         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1440         PetscOptionsEnd();
1441       }
1442       break;
1443     case MATPRODUCT_AtB:
1444       if (product->api_user) {
1445         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1446         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1447         PetscOptionsEnd();
1448       } else {
1449         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1450         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1451         PetscOptionsEnd();
1452       }
1453       break;
1454     case MATPRODUCT_PtAP:
1455       if (product->api_user) {
1456         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1457         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1458         PetscOptionsEnd();
1459       } else {
1460         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1461         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1462         PetscOptionsEnd();
1463       }
1464       break;
1465     default:
1466       break;
1467     }
1468     match = (PetscBool)!usecpu;
1469   }
1470   if (match) {
1471     switch (product->type) {
1472     case MATPRODUCT_AB:
1473     case MATPRODUCT_AtB:
1474     case MATPRODUCT_PtAP:
1475       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1476       break;
1477     default:
1478       break;
1479     }
1480   }
1481   /* fallback to MPIAIJ ops */
1482   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1483   PetscFunctionReturn(PETSC_SUCCESS);
1484 }
1485 
1486 // Mirror of MatCOOStruct_MPIAIJ on device
1487 struct MatCOOStruct_MPIAIJKokkos {
1488   PetscCount           n;
1489   PetscSF              sf;
1490   PetscCount           Annz, Bnnz;
1491   PetscCount           Annz2, Bnnz2;
1492   PetscCountKokkosView Ajmap1, Aperm1;
1493   PetscCountKokkosView Bjmap1, Bperm1;
1494   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1495   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1496   PetscCountKokkosView Cperm1;
1497   MatScalarKokkosView  sendbuf, recvbuf;
1498 
1499   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1500   {
1501     auto exec = PetscGetKokkosExecutionSpace();
1502 
1503     n       = coo_h->n;
1504     sf      = coo_h->sf;
1505     Annz    = coo_h->Annz;
1506     Bnnz    = coo_h->Bnnz;
1507     Annz2   = coo_h->Annz2;
1508     Bnnz2   = coo_h->Bnnz2;
1509     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1510     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1511     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1512     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1513     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1514     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1515     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1516     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1517     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1518     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1519     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1520     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1521     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1522     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1523   }
1524 
1525   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1526 };
1527 
1528 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void **data)
1529 {
1530   PetscFunctionBegin;
1531   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(*data));
1532   PetscFunctionReturn(PETSC_SUCCESS);
1533 }
1534 
1535 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1536 {
1537   PetscContainer             container_h, container_d;
1538   MatCOOStruct_MPIAIJ       *coo_h;
1539   MatCOOStruct_MPIAIJKokkos *coo_d;
1540 
1541   PetscFunctionBegin;
1542   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1543   mat->preallocated = PETSC_TRUE;
1544   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1545   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1546   PetscCall(MatZeroEntries(mat));
1547 
1548   // Copy the COO struct to device
1549   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1550   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1551   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1552 
1553   // Put the COO struct in a container and then attach that to the matrix
1554   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1555   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1556   PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1557   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1558   PetscCall(PetscContainerDestroy(&container_d));
1559   PetscFunctionReturn(PETSC_SUCCESS);
1560 }
1561 
1562 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1563 {
1564   Mat_MPIAIJ                   *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1565   Mat                           A = mpiaij->A, B = mpiaij->B;
1566   MatScalarKokkosView           Aa, Ba;
1567   MatScalarKokkosView           v1;
1568   PetscMemType                  memtype;
1569   PetscContainer                container;
1570   MatCOOStruct_MPIAIJKokkos    *coo;
1571   Kokkos::DefaultExecutionSpace exec = PetscGetKokkosExecutionSpace();
1572 
1573   PetscFunctionBegin;
1574   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1575   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1576 
1577   const auto &n      = coo->n;
1578   const auto &Annz   = coo->Annz;
1579   const auto &Annz2  = coo->Annz2;
1580   const auto &Bnnz   = coo->Bnnz;
1581   const auto &Bnnz2  = coo->Bnnz2;
1582   const auto &vsend  = coo->sendbuf;
1583   const auto &v2     = coo->recvbuf;
1584   const auto &Ajmap1 = coo->Ajmap1;
1585   const auto &Ajmap2 = coo->Ajmap2;
1586   const auto &Aimap2 = coo->Aimap2;
1587   const auto &Bjmap1 = coo->Bjmap1;
1588   const auto &Bjmap2 = coo->Bjmap2;
1589   const auto &Bimap2 = coo->Bimap2;
1590   const auto &Aperm1 = coo->Aperm1;
1591   const auto &Aperm2 = coo->Aperm2;
1592   const auto &Bperm1 = coo->Bperm1;
1593   const auto &Bperm2 = coo->Bperm2;
1594   const auto &Cperm1 = coo->Cperm1;
1595 
1596   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1597   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1598     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1599   } else {
1600     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1601   }
1602 
1603   if (imode == INSERT_VALUES) {
1604     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1605     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1606   } else {
1607     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1608     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1609   }
1610 
1611   PetscCall(PetscLogGpuTimeBegin());
1612   /* Pack entries to be sent to remote */
1613   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1614 
1615   /* Send remote entries to their owner and overlap the communication with local computation */
1616   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1617   /* Add local entries to A and B in one kernel */
1618   Kokkos::parallel_for(
1619     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1620       PetscScalar sum = 0.0;
1621       if (i < Annz) {
1622         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1623         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1624       } else {
1625         i -= Annz;
1626         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1627         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1628       }
1629     });
1630   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1631 
1632   /* Add received remote entries to A and B in one kernel */
1633   Kokkos::parallel_for(
1634     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1635       if (i < Annz2) {
1636         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1637       } else {
1638         i -= Annz2;
1639         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1640       }
1641     });
1642   PetscCall(PetscLogGpuTimeEnd());
1643 
1644   if (imode == INSERT_VALUES) {
1645     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1646     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1647   } else {
1648     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1649     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1650   }
1651   PetscFunctionReturn(PETSC_SUCCESS);
1652 }
1653 
1654 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1655 {
1656   PetscFunctionBegin;
1657   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1658   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1659   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1660   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1661 #if defined(PETSC_HAVE_HYPRE)
1662   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
1663 #endif
1664   PetscCall(MatDestroy_MPIAIJ(A));
1665   PetscFunctionReturn(PETSC_SUCCESS);
1666 }
1667 
1668 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1669 {
1670   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1671   PetscBool   congruent;
1672 
1673   PetscFunctionBegin;
1674   PetscCall(MatHasCongruentLayouts(A, &congruent));
1675   if (congruent) { // square matrix and the diagonals are solely in the diag block
1676     PetscCall(MatShift(mpiaij->A, a));
1677   } else { // too hard, use the general version
1678     PetscCall(MatShift_Basic(A, a));
1679   }
1680   PetscFunctionReturn(PETSC_SUCCESS);
1681 }
1682 
1683 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1684 {
1685   PetscFunctionBegin;
1686   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1687   B->ops->mult                  = MatMult_MPIAIJKokkos;
1688   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1689   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1690   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1691   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1692   B->ops->shift                 = MatShift_MPIAIJKokkos;
1693 
1694   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1695   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1696   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1697   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1698 #if defined(PETSC_HAVE_HYPRE)
1699   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
1700 #endif
1701   PetscFunctionReturn(PETSC_SUCCESS);
1702 }
1703 
1704 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1705 {
1706   Mat         B;
1707   Mat_MPIAIJ *a;
1708 
1709   PetscFunctionBegin;
1710   if (reuse == MAT_INITIAL_MATRIX) {
1711     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1712   } else if (reuse == MAT_REUSE_MATRIX) {
1713     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1714   }
1715   B = *newmat;
1716 
1717   B->boundtocpu = PETSC_FALSE;
1718   PetscCall(PetscFree(B->defaultvectype));
1719   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1720   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1721 
1722   a = static_cast<Mat_MPIAIJ *>(A->data);
1723   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1724   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1725   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1726   PetscCall(MatSetOps_MPIAIJKokkos(B));
1727   PetscFunctionReturn(PETSC_SUCCESS);
1728 }
1729 
1730 /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1732 
1733    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1734 
1735    Options Database Key:
1736 .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1737 
1738   Level: beginner
1739 
1740 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1741 M*/
1742 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1743 {
1744   PetscFunctionBegin;
1745   PetscCall(PetscKokkosInitializeCheck());
1746   PetscCall(MatCreate_MPIAIJ(A));
1747   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1748   PetscFunctionReturn(PETSC_SUCCESS);
1749 }
1750 
1751 /*@C
1752   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format). This matrix will ultimately be pushed down
  to Kokkos for calculations.
1755 
1756   Collective
1757 
1758   Input Parameters:
+ comm  - MPI communicator; on a single-process communicator a `MATSEQAIJKOKKOS` matrix is created, otherwise a `MATMPIAIJKOKKOS` matrix
1760 . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1761            This value should be the same as the local size used in creating the
1762            y vector for the matrix-vector product y = Ax.
1763 . n     - This value should be the same as the local size used in creating the
1764        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
1765        calculated if N is given) For square matrices n is almost always `m`.
1766 . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1767 . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1768 . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1769            (same value is used for all local rows)
1770 . d_nnz - array containing the number of nonzeros in the various rows of the
1771            DIAGONAL portion of the local submatrix (possibly different for each row)
1772            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1773            The size of this array is equal to the number of local rows, i.e `m`.
1774            For matrices you plan to factor you must leave room for the diagonal entry and
1775            put in the entry even if it is zero.
1776 . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1777            submatrix (same value is used for all local rows).
1778 - o_nnz - array containing the number of nonzeros in the various rows of the
1779            OFF-DIAGONAL portion of the local submatrix (possibly different for
1780            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1781            structure. The size of this array is equal to the number
1782            of local rows, i.e `m`.
1783 
1784   Output Parameter:
1785 . A - the matrix
1786 
1787   Level: intermediate
1788 
1789   Notes:
1790   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1791   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1792   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1793 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1795   storage.  That is, the stored row and column indices can begin at
1796   either one (as in Fortran) or zero.
1797 
1798 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1799           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1800 @*/
1801 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1802 {
1803   PetscMPIInt size;
1804 
1805   PetscFunctionBegin;
1806   PetscCall(MatCreate(comm, A));
1807   PetscCall(MatSetSizes(*A, m, n, M, N));
1808   PetscCallMPI(MPI_Comm_size(comm, &size));
1809   if (size > 1) {
1810     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1811     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1812   } else {
1813     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1814     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1815   }
1816   PetscFunctionReturn(PETSC_SUCCESS);
1817 }
1818