xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 76be6f4ff3bd4e251c19fc00ebbebfd58b6e7589)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)
9 {
10   Mat_SeqAIJKokkos *aijkok;
11 
12   PetscFunctionBegin;
13   PetscCall(MatAssemblyEnd_MPIAIJ(A,mode));
14   aijkok = static_cast<Mat_SeqAIJKokkos*>(((Mat_MPIAIJ*)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
15   if (aijkok && aijkok->device_mat_d.data()) {
16     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
17   }
18 
19   PetscFunctionReturn(0);
20 }
21 
22 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])
23 {
24   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
25 
26   PetscFunctionBegin;
27   PetscCall(PetscLayoutSetUp(mat->rmap));
28   PetscCall(PetscLayoutSetUp(mat->cmap));
29 #if defined(PETSC_USE_DEBUG)
30   if (d_nnz) {
31     PetscInt i;
32     for (i=0; i<mat->rmap->n; i++) {
33       PetscCheck(d_nnz[i] >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,d_nnz[i]);
34     }
35   }
36   if (o_nnz) {
37     PetscInt i;
38     for (i=0; i<mat->rmap->n; i++) {
39       PetscCheck(o_nnz[i] >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,o_nnz[i]);
40     }
41   }
42 #endif
43 #if defined(PETSC_USE_CTABLE)
44   PetscCall(PetscTableDestroy(&mpiaij->colmap));
45 #else
46   PetscCall(PetscFree(mpiaij->colmap));
47 #endif
48   PetscCall(PetscFree(mpiaij->garray));
49   PetscCall(VecDestroy(&mpiaij->lvec));
50   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
51   /* Because the B will have been resized we simply destroy it and create a new one each time */
52   PetscCall(MatDestroy(&mpiaij->B));
53 
54   if (!mpiaij->A) {
55     PetscCall(MatCreate(PETSC_COMM_SELF,&mpiaij->A));
56     PetscCall(MatSetSizes(mpiaij->A,mat->rmap->n,mat->cmap->n,mat->rmap->n,mat->cmap->n));
57     PetscCall(PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->A));
58   }
59   if (!mpiaij->B) {
60     PetscMPIInt size;
61     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat),&size));
62     PetscCall(MatCreate(PETSC_COMM_SELF,&mpiaij->B));
63     PetscCall(MatSetSizes(mpiaij->B,mat->rmap->n,size > 1 ? mat->cmap->N : 0,mat->rmap->n,size > 1 ? mat->cmap->N : 0));
64     PetscCall(PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->B));
65   }
66   PetscCall(MatSetType(mpiaij->A,MATSEQAIJKOKKOS));
67   PetscCall(MatSetType(mpiaij->B,MATSEQAIJKOKKOS));
68   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A,d_nz,d_nnz));
69   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B,o_nz,o_nnz));
70   mat->preallocated = PETSC_TRUE;
71   PetscFunctionReturn(0);
72 }
73 
74 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
75 {
76   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
77   PetscInt       nt;
78 
79   PetscFunctionBegin;
80   PetscCall(VecGetLocalSize(xx,&nt));
81   PetscCheck(nt == mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
82   PetscCall(VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
83   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A,xx,yy));
84   PetscCall(VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
85   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,yy,yy));
86   PetscFunctionReturn(0);
87 }
88 
89 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)
90 {
91   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
92   PetscInt       nt;
93 
94   PetscFunctionBegin;
95   PetscCall(VecGetLocalSize(xx,&nt));
96   PetscCheck(nt == mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
97   PetscCall(VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
98   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A,xx,yy,zz));
99   PetscCall(VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
100   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,zz,zz));
101   PetscFunctionReturn(0);
102 }
103 
104 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
105 {
106   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
107   PetscInt       nt;
108 
109   PetscFunctionBegin;
110   PetscCall(VecGetLocalSize(xx,&nt));
111   PetscCheck(nt == mat->rmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->rmap->n,nt);
112   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B,xx,mpiaij->lvec));
113   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A,xx,yy));
114   PetscCall(VecScatterBegin(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE));
115   PetscCall(VecScatterEnd(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE));
116   PetscFunctionReturn(0);
117 }
118 
119 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
120    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
121    C still uses local column ids. Their corresponding global column ids are returned in glob.
122 */
123 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS *glob,Mat *C)
124 {
125   Mat            Ad,Ao;
126   const PetscInt *cmap;
127 
128   PetscFunctionBegin;
129   PetscCall(MatMPIAIJGetSeqAIJ(mat,&Ad,&Ao,&cmap));
130   PetscCall(MatSeqAIJKokkosMergeMats(Ad,Ao,reuse,C));
131   if (glob) {
132     PetscInt cst, i, dn, on, *gidx;
133     PetscCall(MatGetLocalSize(Ad,NULL,&dn));
134     PetscCall(MatGetLocalSize(Ao,NULL,&on));
135     PetscCall(MatGetOwnershipRangeColumn(mat,&cst,NULL));
136     PetscCall(PetscMalloc1(dn+on,&gidx));
137     for (i=0; i<dn; i++) gidx[i]    = cst + i;
138     for (i=0; i<on; i++) gidx[i+dn] = cmap[i];
139     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad),dn+on,gidx,PETSC_OWN_POINTER,glob));
140   }
141   PetscFunctionReturn(0);
142 }
143 
144 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
145 struct MatMatStruct {
146   MatRowMapKokkosView   Cdstart; /* Used to split sequential matrix into petsc's A, B format */
147   PetscSF               sf; /* SF to send/recv matrix entries */
148   MatScalarKokkosView   abuf; /* buf of mat values in send/recv */
149   Mat                   C1,C2,B_local;
150   KokkosCsrMatrix       C1_global,C2_global,C_global;
151   KernelHandle          kh;
152   MatMatStruct() {
153     C1 = C2 = B_local = NULL;
154     sf = NULL;
155   }
156 
157   ~MatMatStruct() {
158     MatDestroy(&C1);
159     MatDestroy(&C2);
160     MatDestroy(&B_local);
161     PetscSFDestroy(&sf);
162     kh.destroy_spadd_handle();
163   }
164 };
165 
166 struct MatMatStruct_AB : public MatMatStruct {
167   MatColIdxKokkosView   rows;
168   MatRowMapKokkosView   rowoffset;
169   Mat                   B_other,C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
170 
171   MatMatStruct_AB() : B_other(NULL),C_petsc(NULL){}
172   ~MatMatStruct_AB() {
173     MatDestroy(&B_other);
174     MatDestroy(&C_petsc);
175   }
176 };
177 
178 struct MatMatStruct_AtB : public MatMatStruct {
179   MatRowMapKokkosView   srcrowoffset,dstrowoffset;
180 };
181 
182 struct MatProductData_MPIAIJKokkos
183 {
184   MatMatStruct_AB   *mmAB;
185   MatMatStruct_AtB  *mmAtB;
186   PetscBool         reusesym;
187 
188   MatProductData_MPIAIJKokkos(): mmAB(NULL),mmAtB(NULL),reusesym(PETSC_FALSE){}
189   ~MatProductData_MPIAIJKokkos() {
190     delete mmAB;
191     delete mmAtB;
192   }
193 };
194 
195 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
196 {
197   PetscFunctionBegin;
198   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos*>(data));
199   PetscFunctionReturn(0);
200 }
201 
202 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
203 
204    Input Parameters:
205 +  A       - the MATSEQAIJKOKKOS matrix
206 .  N       - new column size for the returned Kokkos matrix
207 -  l2g     - a map that maps old col ids to new col ids
208 
209    Output Parameters:
210 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
211  */
212 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A,PetscInt N,const ConstMatColIdxKokkosView& l2g,KokkosCsrMatrix& csrmat)
213 {
214   KokkosCsrMatrix&         orig = static_cast<Mat_SeqAIJKokkos*>(A->spptr)->csrmat;
215   MatColIdxKokkosView      jg("jg",orig.nnz()); /* New j array for csrmat */
216 
217   PetscFunctionBegin;
218   PetscCallCXX(Kokkos::parallel_for(orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) {jg(i) = l2g(orig.graph.entries(i));}));
219   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat",orig.numRows(),N,orig.nnz(),orig.values,orig.graph.row_map,jg));
220   PetscFunctionReturn(0);
221 }
222 
223 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
224    It is similar to MatCreateMPIAIJWithSplitArrays.
225 
226   Input Parameters:
227 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
228 .  A     - the diag matrix using local col ids
229 -  B     - the offdiag matrix using global col ids
230 
231   Output Parameters:
232 .  mat   - the updated MATMPIAIJKOKKOS matrix
233 */
234 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat,Mat A,Mat B)
235 {
236   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
237   PetscInt            m,n,M,N,Am,An,Bm,Bn;
238   Mat_SeqAIJKokkos    *bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
239 
240   PetscFunctionBegin;
241   PetscCall(MatGetSize(mat,&M,&N));
242   PetscCall(MatGetLocalSize(mat,&m,&n));
243   PetscCall(MatGetLocalSize(A,&Am,&An));
244   PetscCall(MatGetLocalSize(B,&Bm,&Bn));
245 
246   PetscCheck(m == Am && m == Bm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of rows do not match");
247   PetscCheck(n == An,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of columns do not match");
248   PetscCheck(N == Bn,PETSC_COMM_SELF,PETSC_ERR_PLIB,"global number of columns do not match");
249   PetscCheck(!mpiaij->A && !mpiaij->B,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A, B of the MPIAIJ matrix are not empty");
250   mpiaij->A = A;
251   mpiaij->B = B;
252 
253   mat->preallocated     = PETSC_TRUE;
254   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
255 
256   PetscCall(MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE));
257   PetscCall(MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY));
258   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
259     also gets mpiaij->B compacted, with its col ids and size reduced
260   */
261   PetscCall(MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY));
262   PetscCall(MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_FALSE));
263   PetscCall(MatSetOption(mat,MAT_NEW_NONZERO_LOCATION_ERR,PETSC_TRUE));
264 
265   /* Update bkok with new local col ids (stored on host) and size */
266   bkok->j_dual.modify_host();
267   bkok->j_dual.sync_device();
268   bkok->SetColSize(mpiaij->B->cmap->n);
269   PetscFunctionReturn(0);
270 }
271 
272 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
273 
274    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
275    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
276    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
277    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
278 
279    Collective on comm of ownerSF
280 
281    Input Parameters:
282 +   B       - the SEQAIJKOKKOS matrix, using local col ids
283 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
284 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
285 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
286 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
287 
288    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
289 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
290 .   abuf      - buffer for sending matrix values
291 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
292                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
293 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
294 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
295 */
296 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B,MatReuse reuse,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
297                                            PetscSF& bcastSF,MatScalarKokkosView& abuf,MatColIdxKokkosView& rows,
298                                            MatRowMapKokkosView& rowoffset,Mat& C)
299 {
300   Mat_SeqAIJKokkos             *bkok,*ckok;
301 
302   PetscFunctionBegin;
303   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
304   bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
305 
306   if (reuse == MAT_REUSE_MATRIX) {
307     ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
308 
309     const auto& Ba = bkok->a_dual.view_device();
310     const auto& Bi = bkok->i_dual.view_device();
311     const auto& Ca = ckok->a_dual.view_device();
312 
313     /* Copy Ba to abuf */
314     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
315       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
316       PetscInt r    = rows(i);
317       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
318       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
319         abuf(base+k) = Ba(Bi(r)+k);
320       });
321     });
322 
323     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
324     PetscCall(PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE)); /* TODO: get memtype for abuf */
325     PetscCall(PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE));
326     ckok->a_dual.modify_device();
327   } else if (reuse == MAT_INITIAL_MATRIX) {
328     MPI_Comm       comm;
329     PetscMPIInt    tag;
330     PetscInt       k,Cm,Cn,Cnnz,*Ci_h,nroots,nleaves;
331 
332     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF,&comm));
333     PetscCall(PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL));
334     Cm   = nleaves; /* row size of C */
335     Cn   = N;  /* col size of C, which initially uses global ids, so we can safely set its col size as N */
336 
337     /* Get row lens (nz) of B's rows for later fast query */
338     PetscInt       *Browlens;
339     const PetscInt *tmp = bkok->i_host_data();
340     PetscCall(PetscMalloc1(nroots,&Browlens));
341     for (k=0; k<nroots; k++) Browlens[k] = tmp[k+1]-tmp[k];
342 
343     /* By ownerSF, each proc gets lens of rows of C */
344     MatRowMapKokkosDualView Ci("i",Cm+1); /* C's rowmap */
345     Ci_h    = Ci.view_host().data();
346     Ci_h[0] = 0;
347     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF,MPIU_INT,PETSC_MEMTYPE_HOST,Browlens,PETSC_MEMTYPE_HOST,&Ci_h[1],MPI_REPLACE));
348     PetscCall(PetscSFBcastEnd(ownerSF,MPIU_INT,Browlens,&Ci_h[1],MPI_REPLACE));
349     for (k=1; k<Cm+1; k++) Ci_h[k] += Ci_h[k-1]; /* Convert lens to CSR */
350     Cnnz    = Ci_h[Cm];
351     Ci.modify_host();
352     Ci.sync_device();
353 
354     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
355     MatColIdxKokkosDualView  Cj("j",Cnnz);
356     MatScalarKokkosDualView  Ca("a",Cnnz);
357 
358     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
359     const PetscMPIInt *iranks,*ranks;
360     const PetscInt    *ioffset,*irootloc,*roffset;
361     PetscInt          i,j,niranks,nranks,*sdisp,*rdisp,*rowptr;
362     MPI_Request       *reqs;
363 
364     PetscCall(PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&irootloc)); /* irootloc[] contains indices of rows I need to send to each receiver */
365     PetscCall(PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/)); /* recv info */
366 
367     /* figure out offsets at the send buffer, to build the SF
368       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
369       rowptr[] - stores offsets for data of each row in abuf
370 
371       rdisp[]  - to receive sdisp[]
372     */
373     PetscCall(PetscMalloc3(niranks+1,&sdisp,nranks,&rdisp,niranks+nranks,&reqs));
374     MatRowMapKokkosViewHost rowptr_h("rowptr_h",ioffset[niranks]+1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
375     rowptr = rowptr_h.data();
376 
377     sdisp[0] = 0;
378     rowptr[0]  = 0;
379     for (i=0; i<niranks; i++) { /* for each receiver */
380       PetscInt len, nz = 0;
381       for (j=ioffset[i]; j<ioffset[i+1]; j++) { /* for each row to this receiver */
382         len         = Browlens[irootloc[j]];
383         rowptr[j+1] = rowptr[j] + len;
384         nz         += len;
385       }
386       sdisp[i+1] = sdisp[i] + nz;
387     }
388     PetscCallMPI(PetscCommGetNewTag(comm,&tag));
389     for (i=0; i<nranks; i++)  PetscCallMPI(MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]));
390     for (i=0; i<niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]));
391     PetscCallMPI(MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE));
392 
393     PetscInt    nleaves2 = Cnnz; /* leaves are the nonzeros I will receive */
394     PetscInt    nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
395     PetscSFNode *iremote;
396     PetscCall(PetscMalloc1(nleaves2,&iremote));
397     for (i=0; i<nranks; i++) { /* for each sender */
398       k = 0;
399       for (j=Ci_h[roffset[i]]; j<Ci_h[roffset[i+1]]; j++) {
400         iremote[j].rank  = ranks[i];
401         iremote[j].index = rdisp[i] + k;
402         k++;
403       }
404     }
405     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
406     PetscCall(PetscSFCreate(comm,&bcastSF));
407     PetscCall(PetscSFSetGraph(bcastSF,nroots2,nleaves2,NULL/*ilocal*/,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER));
408 
409     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
410       from local to global. Then use bcastSF to fill Ca, Cj.
411     */
412     ConstMatColIdxKokkosViewHost rows_h(irootloc,ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
413     MatColIdxKokkosView          rows("rows",ioffset[niranks]);
414     Kokkos::deep_copy(rows,rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
415 
416     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
417 
418     MatColIdxKokkosView jbuf("jbuf",sdisp[niranks]); /* send buf for (global) col ids */
419     abuf = MatScalarKokkosView("abuf",sdisp[niranks]); /* send buf for mat values */
420 
421     const auto& Ba = bkok->a_dual.view_device();
422     const auto& Bi = bkok->i_dual.view_device();
423     const auto& Bj = bkok->j_dual.view_device();
424 
425     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
426     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
427       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
428       PetscInt r    = rows(i);
429       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
430       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
431         abuf(base+k) = Ba(Bi(r)+k);
432         jbuf(base+k) = l2g(Bj(Bi(r)+k));
433       });
434     });
435 
436     /* Send abuf & jbuf to fill Ca, Cj */
437     PetscCall(PetscSFBcastBegin(bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE));
438     PetscCall(PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE));
439     PetscCall(PetscSFBcastEnd  (bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE));
440     PetscCall(PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE));
441     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
442     Cj.sync_host();
443     Ca.modify_device();
444 
445     /* Construct C with Ca, Ci, Cj */
446     auto ckok = new Mat_SeqAIJKokkos(Cm,Cn,Cnnz,Ci,Cj,Ca);
447     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,ckok,&C));
448     PetscCall(PetscFree3(sdisp,rdisp,reqs));
449     PetscCall(PetscFree(Browlens));
450   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
451   PetscFunctionReturn(0);
452 }
453 
454 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
455 
456   It is the reverse of MatSeqAIJKokkosBcast in some sense.
457 
458   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
459   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
460   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
461 
462   Input Parameters:
463 +  A        - the SEQAIJKOKKOS matrix to be reduced
464 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
465 .  local    - true if A uses local col ids; false if A is already in global col ids.
466 .  N        - if local, N is A's global col size
467 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
468 -  ownerSF  - the SF specifies ownership (root) of rows in A
469 
470   Output Parameters:
471 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
472 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
473 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
474 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
475                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
476 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
477 
478    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
479  */
480 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A,MatReuse reuse,PetscBool local,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
481                                             PetscSF& reduceSF,MatScalarKokkosView& abuf,
482                                             MatRowMapKokkosView& srcrowoffset,MatRowMapKokkosView& dstrowoffset,
483                                             KokkosCsrMatrix& C)
484 {
485   PetscInt               i,r,Am,An,Annz,Cnnz,nrows;
486   const PetscInt         *Ai;
487   Mat_SeqAIJKokkos       *akok;
488 
489   PetscFunctionBegin;
490   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
491   PetscCall(MatGetSize(A,&Am,&An));
492   Ai   = static_cast<Mat_SeqAIJ*>(A->data)->i;
493   akok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
494   Annz = Ai[Am];
495 
496   if (reuse == MAT_REUSE_MATRIX) {
497     /* Send Aa to abuf */
498     PetscCallMPI(PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
499     PetscCallMPI(PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
500 
501     /* Copy abuf to Ca */
502     const MatScalarKokkosView& Ca = C.values;
503     nrows = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
504     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
505       PetscInt i   = t.league_rank();
506       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
507       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
508       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {Ca(dst+k) = abuf(src+k);});
509     });
510   } else if (reuse == MAT_INITIAL_MATRIX) {
511     MPI_Comm               comm;
512     MPI_Request            *reqs;
513     PetscMPIInt            tag;
514     PetscInt               Cm;
515 
516     PetscCall(PetscObjectGetComm((PetscObject)ownerSF,&comm));
517     PetscCall(PetscCommGetNewTag(comm,&tag));
518 
519     PetscInt niranks,nranks,nroots,nleaves;
520     const PetscMPIInt *iranks,*ranks;
521     const PetscInt *ioffset,*rows,*roffset;  /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
522     PetscCall(PetscSFSetUp(ownerSF));
523     PetscCall(PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&rows)); /* recv info: iranks[] will send rows to me */
524     PetscCall(PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/)); /* send info */
525     PetscCall(PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL));
526     PetscCheck(nleaves == Am,PETSC_COMM_SELF,PETSC_ERR_PLIB,"ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")",nleaves,Am);
527     Cm    = nroots;
528     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
529 
530     /* Tell owners how long each row I will send */
531     PetscInt                *srowlens; /* send buf of row lens */
532     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h",nrows+1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
533     PetscInt                *rrowlens = rrowlens_h.data();
534 
535     PetscCall(PetscMalloc2(Am,&srowlens,niranks+nranks,&reqs));
536     for (i=0; i<Am; i++) srowlens[i] = Ai[i+1] - Ai[i];
537     rrowlens[0] = 0;
538     rrowlens++; /* shift the pointer to make the following expression more readable */
539     for (i=0; i<niranks; i++)PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]],ioffset[i+1]-ioffset[i],MPIU_INT,iranks[i],tag,comm,&reqs[i]));
540     for (i=0; i<nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]],roffset[i+1]-roffset[i],MPIU_INT,ranks[i],tag,comm,&reqs[niranks+i]));
541     PetscCallMPI(MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE));
542 
543     /* Owner builds Ci on host by histogramming rrowlens[] */
544     MatRowMapKokkosViewHost Ci_h("i",Cm+1);
545     Kokkos::deep_copy(Ci_h,0); /* Zero Ci */
546     MatRowMapType *Ci_ptr = Ci_h.data();
547 
548     for (i=0; i<nrows; i++) {
549       r = rows[i]; /* local row id of i-th received row */
550      #if defined(PETSC_USE_DEBUG)
551       PetscCheck(r >= 0 && r < Cm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")",r,Cm);
552      #endif
553       Ci_ptr[r+1] += rrowlens[i]; /* add to length of row r in C */
554     }
555     for (i=0; i<Cm; i++) Ci_ptr[i+1] += Ci_ptr[i]; /* to CSR format */
556     Cnnz = Ci_ptr[Cm];
557 
558     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
559     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h",nrows);
560     PetscInt                *dstrowoffset_hptr = dstrowoffset_h.data();
561     PetscInt                *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
562 
563     PetscCall(PetscCalloc1(Cm,&currowlens)); /* Init with zero, to be added to */
564     for (i=0; i<nrows; i++) { /* for each row I receive */
565       r                    = rows[i]; /* row id in C */
566       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
567       currowlens[r]       += rrowlens[i]; /* accumulate to length of row r in C */
568     }
569     PetscCall(PetscFree(currowlens));
570 
571     rrowlens--;
572     for (i=0; i<nrows; i++) rrowlens[i+1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
573     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),dstrowoffset_h);
574     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
575 
576     /* Build the reduceSF, which performs buffer to buffer send/recv */
577     PetscInt *sdisp,*rdisp; /* buffer to send offsets of roots, and buffer to recv them */
578     PetscCall(PetscMalloc2(niranks,&sdisp,nranks,&rdisp));
579     for (i=0; i<niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
580     for (i=0; i<nranks; i++)  PetscCallMPI(MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]));
581     for (i=0; i<niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]));
582     PetscCallMPI(MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE));
583 
584     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
585     PetscInt    nroots2 = Cnnz,nleaves2 = Annz;
586     PetscSFNode *iremote;
587     PetscCall(PetscMalloc1(nleaves2,&iremote)); /* no free, since memory will be given to reduceSF */
588     for (i=0; i<nranks; i++) {
589       PetscInt rootbase = rdisp[i]; /* root offset at this root rank */
590       PetscInt leafbase = Ai[roffset[i]]; /* leaf base */
591       PetscInt nz       = Ai[roffset[i+1]] - leafbase; /* I will send nz nonzeros to this root rank */
592       for (PetscInt k=0; k<nz; k++) {
593         iremote[leafbase+k].rank  = ranks[i];
594         iremote[leafbase+k].index = rootbase + k;
595       }
596     }
597     PetscCall(PetscSFCreate(comm,&reduceSF));
598     PetscCall(PetscSFSetGraph(reduceSF,nroots2,nleaves2,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER));
599     PetscCall(PetscFree2(sdisp,rdisp));
600 
601     /* Reduce Aa, Ajg to abuf and jbuf */
602 
603     /* If A uses local col ids, convert them to global ones before sending */
604     MatColIdxKokkosView Ajg;
605     if (local) {
606       Ajg = MatColIdxKokkosView("j",Annz);
607       const MatColIdxKokkosView& Aj = akok->j_dual.view_device();
608       Kokkos::parallel_for(Annz,KOKKOS_LAMBDA(const PetscInt i) {Ajg(i) = l2g(Aj(i));});
609     } else {
610       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
611     }
612 
613     MatColIdxKokkosView   jbuf("jbuf",Cnnz);
614     abuf = MatScalarKokkosView("abuf",Cnnz);
615     PetscCallMPI(PetscSFReduceBegin(reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE));
616     PetscCallMPI(PetscSFReduceEnd  (reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE));
617     PetscCallMPI(PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
618     PetscCallMPI(PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
619 
620     /* Copy data from abuf, jbuf to Ca, Cj */
621     MatRowMapKokkosView    Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),Ci_h); /* Ci is an alias of Ci_h if no device */
622     MatColIdxKokkosView    Cj("j",Cnnz);
623     MatScalarKokkosView    Ca("a",Cnnz);
624 
625     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
626       PetscInt i   = t.league_rank();
627       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
628       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
629       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {
630         Ca(dst+k) = abuf(src+k);
631         Cj(dst+k) = jbuf(src+k);
632       });
633     });
634 
635     /* Build C with Ca, Ci, Cj */
636     C    = KokkosCsrMatrix("csrmat",Cm,N,Cnnz,Ca,Ci,Cj);
637     PetscCall(PetscFree2(srowlens,reqs));
638   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
639   PetscFunctionReturn(0);
640 }
641 
642 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
643 
644   Input Parameters:
645 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
646 .  reuse    - indicate whether the matrix has called this function before
647 .  csrmat   - the KokkosCsrMatrix, of size m,N
648 -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
649               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
650               entry is 5, then Cdstart[i] = 3.
651 
652   Output Parameters:
653 +  C        - the updated MATMPIAIJKOKKOS matrix
654 -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
655 
656   Notes:
657    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
658  */
659 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C,MatReuse reuse,const KokkosCsrMatrix& csrmat,MatRowMapKokkosView& Cdstart)
660 {
661   const MatScalarKokkosView&      Ca = csrmat.values;
662   const ConstMatRowMapKokkosView& Ci = csrmat.graph.row_map;
663   PetscInt                        m,n,N;
664 
665   PetscFunctionBegin;
666   PetscCall(MatGetLocalSize(C,&m,&n));
667   PetscCall(MatGetSize(C,NULL,&N));
668 
669   if (reuse == MAT_REUSE_MATRIX) {
670     Mat_MPIAIJ                  *mpiaij = static_cast<Mat_MPIAIJ*>(C->data);
671     Mat_SeqAIJKokkos            *akok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr);
672     Mat_SeqAIJKokkos            *bkok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->B->spptr);
673     const MatScalarKokkosView&  Cda = akok->a_dual.view_device(),Coa = bkok->a_dual.view_device();
674     const MatRowMapKokkosView&  Cdi = akok->i_dual.view_device(),Coi = bkok->i_dual.view_device();
675 
676     /* Fill 'a' of Cd and Co on device */
677     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
678       PetscInt i       = t.league_rank(); /* row i */
679       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
680       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
681       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
682       PetscInt cdend   = cdstart + cdlen;
683       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
684       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
685         if (k < cdstart) {  /* k in [0, cdstart) */
686           Coa(Coi(i)+k) = Ca(Ci(i)+k);
687         } else if (k < cdend) { /* k in [cdstart, cdend) */
688           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
689         } else { /* k in [cdend, clen) */
690           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
691         }
692       });
693     });
694 
695     akok->a_dual.modify_device();
696     bkok->a_dual.modify_device();
697   } else if (reuse == MAT_INITIAL_MATRIX) {
698     Mat                         Cd,Co;
699     const MatColIdxKokkosView&  Cj = csrmat.graph.entries;
700     MatRowMapKokkosDualView     Cdi_dual("i",m+1),Coi_dual("i",m+1);
701     MatRowMapKokkosView         Cdi = Cdi_dual.view_device(),Coi = Coi_dual.view_device();
702     PetscInt                    cstart,cend;
703 
704     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
705        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
706        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
707        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
708        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
709      */
710     Cdstart = MatRowMapKokkosView("Cdstart",m);
711     PetscCall(PetscLayoutGetRange(C->cmap,&cstart,&cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
712 
713     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
714       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
715      */
716     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, 1),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
717       Kokkos::single(Kokkos::PerTeam(t), [=] () { /* Only one thread works in a team */
718         PetscInt i = t.league_rank(); /* row i */
719         PetscInt j,first,count,step;
720 
721         if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
722           Cdi(0) = 0;
723           Coi(0) = 0;
724         }
725 
726         /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
727           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
728         */
729         count = Ci(i+1)-Ci(i);
730         first = Ci(i);
731         while (count > 0) {
732           j    = first;
733           step = count / 2;
734           j   += step;
735           if (Cj(j) < cstart) {
736             first  = ++j;
737             count -= step + 1;
738           } else count = step;
739         }
740         Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
741 
742         /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
743         count = Ci(i+1) - first;
744         while (count > 0) {
745           j    = first;
746           step = count / 2;
747           j   += step;
748           if (Cj(j) < cend) {
749             first  = ++j;
750             count -= step + 1;
751           } else count = step;
752         }
753         Cdi(i+1) = first - (Ci(i)+Cdstart(i)); /* 'first' is the while-loop's output */
754         Coi(i+1) = (Ci(i+1)-Ci(i)) - Cdi(i+1); /* Co's row len = C's row len - Cd's row len */
755       });
756     });
757 
758     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
759     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
760       update += Cdi(i);
761       if (final) Cdi(i) = update;
762     });
763     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
764       update += Coi(i);
765       if (final) Coi(i) = update;
766     });
767 
768     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
769        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
770     */
771     Cdi_dual.modify_device();
772     Coi_dual.modify_device();
773     Cdi_dual.sync_host();
774     Coi_dual.sync_host();
775     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
776     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
777 
778     /* With nnz, allocate a, j for Cd and Co */
779     MatColIdxKokkosDualView Cdj_dual("j",Cd_nnz),Coj_dual("j",Co_nnz);
780     MatScalarKokkosDualView Cda_dual("a",Cd_nnz),Coa_dual("a",Co_nnz);
781 
782     /* Fill a, j of Cd and Co on device */
783     MatColIdxKokkosView     Cdj = Cdj_dual.view_device(),Coj = Coj_dual.view_device();
784     MatScalarKokkosView     Cda = Cda_dual.view_device(),Coa = Coa_dual.view_device();
785 
786     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
787       PetscInt i       = t.league_rank(); /* row i */
788       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
789       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
790       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
791       PetscInt cdend   = cdstart + cdlen;
792       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
793       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
794         if (k < cdstart) { /* k in [0, cdstart) */
795           Coa(Coi(i)+k) = Ca(Ci(i)+k);
796           Coj(Coi(i)+k) = Cj(Ci(i)+k);
797         } else if (k < cdend) { /* k in [cdstart, cdend) */
798           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
799           Cdj(Cdi(i)+(k-cdstart)) = Cj(Ci(i)+k) - cstart; /* Use local col ids in Cdj */
800         } else { /* k in [cdend, clen) */
801           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
802           Coj(Coi(i)+k-cdlen) = Cj(Ci(i)+k);
803         }
804       });
805     });
806 
807     Cdj_dual.modify_device();
808     Cda_dual.modify_device();
809     Coj_dual.modify_device();
810     Coa_dual.modify_device();
811     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
812     auto cdkok = new Mat_SeqAIJKokkos(m,n,Cd_nnz,Cdi_dual,Cdj_dual,Cda_dual);
813     auto cokok = new Mat_SeqAIJKokkos(m,N,Co_nnz,Coi_dual,Coj_dual,Coa_dual);
814     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cdkok,&Cd));
815     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cokok,&Co));
816     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C,Cd,Co)); /* Coj will be converted to local ids within */
817   }
818   PetscFunctionReturn(0);
819 }
820 
821 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
822 
823   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
824 
825   Input Parameters:
826 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
827 .  reuse    - indicate whether the matrix has called this function before
828 .  csrmat   - the KokkosCsrMatrix, of size m,N
829 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
830               entry of the diag block of C in csrmat's j array.
831 
832   Output Parameters:
833 +  C        - the updated MATMPIAIJKOKKOS matrix
834 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
835 
836   Notes: the input matrix's col ids and col size will be changed.
837 */
838 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C,MatColIdxKokkosView& l2g)
839 {
840   Mat_SeqAIJKokkos       *ckok;
841   ISLocalToGlobalMapping l2gmap;
842   const PetscInt         *garray;
843   PetscInt               sz;
844 
845   PetscFunctionBegin;
846   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
847   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C,&l2gmap));
848   ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
849   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
850   ckok->j_dual.sync_device();
851   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
852 
853   /* Build l2g -- the local to global mapping of C's cols */
854   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap,&garray));
855   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap,&sz));
856   PetscCheck(C->cmap->n == sz,PETSC_COMM_SELF,PETSC_ERR_PLIB,"matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n,sz);
857 
858   ConstMatColIdxKokkosViewHost tmp(garray,sz);
859   l2g = MatColIdxKokkosView("l2g",sz);
860   Kokkos::deep_copy(l2g,tmp);
861 
862   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap,&garray));
863   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
864   PetscFunctionReturn(0);
865 }
866 
867 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
868 
869   Input Parameters:
870 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
871 .  A        - an MPIAIJKOKKOS matrix
872 .  B        - an MPIAIJKOKKOS matrix
873 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
874 
875   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
876 */
877 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product,Mat A,Mat B,MatMatStruct_AB *mm)
878 {
879   Mat_MPIAIJ                  *a = static_cast<Mat_MPIAIJ*>(A->data);
880   Mat                         Ad = a->A,Ao = a->B; /* diag and offdiag of A */
881   IS                          glob = NULL;
882   const PetscInt              *garray;
883   PetscInt                    N = B->cmap->N,sz;
884   ConstMatColIdxKokkosView    l2g1; /* two temp maps mapping local col ids to global ones */
885   MatColIdxKokkosView         l2g2;
886   Mat                         C1,C2; /* intermediate matrices */
887 
888   PetscFunctionBegin;
889   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
890   PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&mm->B_local));
891   PetscCall(MatProductCreate(Ad,mm->B_local,NULL,&C1));
892   PetscCall(MatProductSetType(C1,MATPRODUCT_AB));
893   PetscCall(MatProductSetFill(C1,product->fill));
894   C1->product->api_user = product->api_user;
895   PetscCall(MatProductSetFromOptions(C1));
896   PetscCheck(C1->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
897   PetscCall((*C1->ops->productsymbolic)(C1));
898 
899   PetscCall(ISGetIndices(glob,&garray));
900   PetscCall(ISGetSize(glob,&sz));
901   const auto& tmp  = ConstMatColIdxKokkosViewHost(garray,sz); /* wrap garray as a view */
902   l2g1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
903   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g1,mm->C1_global));
904 
905   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
906   PetscCall(MatSeqAIJKokkosBcast(mm->B_local,MAT_INITIAL_MATRIX,N,l2g1,a->Mvctx,mm->sf,mm->abuf,mm->rows,mm->rowoffset,mm->B_other));
907 
908   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
909   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other,l2g2));
910   PetscCall(MatProductCreate(Ao,mm->B_other,NULL,&C2));
911   PetscCall(MatProductSetType(C2,MATPRODUCT_AB));
912   PetscCall(MatProductSetFill(C2,product->fill));
913   C2->product->api_user = product->api_user;
914   PetscCall(MatProductSetFromOptions(C2));
915   PetscCheck(C2->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
916   PetscCall((*C2->ops->productsymbolic)(C2));
917   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2,N,l2g2,mm->C2_global));
918 
919   /* C = C1 + C2.  We actually use their global col ids versions in adding */
920   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
921   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
922   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
923   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
924 
925   mm->C1 = C1;
926   mm->C2 = C2;
927   PetscCall(ISRestoreIndices(glob,&garray));
928   PetscCall(ISDestroy(&glob));
929   PetscFunctionReturn(0);
930 }
931 
932 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
933 
934   Input Parameters:
935 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
936 .  A        - an MPIAIJKOKKOS matrix
937 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
938 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
939 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
940 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
941 -  mm       - a struct used to stash intermediate data in AtB
942 
943   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
944 */
945 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product,Mat A,Mat B,PetscBool localB,PetscInt N,const ConstMatColIdxKokkosView& l2g,MatMatStruct_AtB *mm)
946 {
947   Mat_MPIAIJ             *a = static_cast<Mat_MPIAIJ*>(A->data);
948   Mat                    Ad = a->A,Ao = a->B; /* diag and offdiag of A */
949   Mat                    C1,C2; /* intermediate matrices */
950 
951   PetscFunctionBegin;
952   /* C1 = Ad^t * B */
953   PetscCall(MatProductCreate(Ad,B,NULL,&C1));
954   PetscCall(MatProductSetType(C1,MATPRODUCT_AtB));
955   PetscCall(MatProductSetFill(C1,product->fill));
956   C1->product->api_user = product->api_user;
957   PetscCall(MatProductSetFromOptions(C1));
958   PetscCheck(C1->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
959   PetscCall((*C1->ops->productsymbolic)(C1));
960 
961   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g,mm->C1_global));
962   else mm->C1_global = static_cast<Mat_SeqAIJKokkos*>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
963 
964   /* C2 = Ao^t * B */
965   PetscCall(MatProductCreate(Ao,B,NULL,&C2));
966   PetscCall(MatProductSetType(C2,MATPRODUCT_AtB));
967   PetscCall(MatProductSetFill(C2,product->fill));
968   C2->product->api_user = product->api_user;
969   PetscCall(MatProductSetFromOptions(C2));
970   PetscCheck(C2->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
971   PetscCall((*C2->ops->productsymbolic)(C2));
972 
973   PetscCall(MatSeqAIJKokkosReduce(C2,MAT_INITIAL_MATRIX,localB,N,l2g,a->Mvctx,mm->sf,mm->abuf,mm->srcrowoffset,mm->dstrowoffset,mm->C2_global));
974 
975   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
976   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
977   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
978   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
979   mm->C1 = C1;
980   mm->C2 = C2;
981   PetscFunctionReturn(0);
982 }
983 
984 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
985 {
986   Mat_Product                   *product = C->product;
987   MatProductType                ptype;
988   MatProductData_MPIAIJKokkos   *mmdata;
989   MatMatStruct                  *mm = NULL;
990   MatMatStruct_AB               *ab;
991   MatMatStruct_AtB              *atb;
992   Mat                           A,B,Ad,Ao,Bd,Bo;
993   const MatScalarType           one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
994 
995   PetscFunctionBegin;
996   MatCheckProduct(C,1);
997   mmdata = static_cast<MatProductData_MPIAIJKokkos*>(product->data);
998   ptype  = product->type;
999   A      = product->A;
1000   B      = product->B;
1001   Ad     = static_cast<Mat_MPIAIJ*>(A->data)->A;
1002   Ao     = static_cast<Mat_MPIAIJ*>(A->data)->B;
1003   Bd     = static_cast<Mat_MPIAIJ*>(B->data)->A;
1004   Bo     = static_cast<Mat_MPIAIJ*>(B->data)->B;
1005 
1006   if (mmdata->reusesym) { /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1007     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1008     ab  = mmdata->mmAB;
1009     atb = mmdata->mmAtB;
1010     if (ab) {
1011       static_cast<MatProductData_SeqAIJKokkos*>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1012       static_cast<MatProductData_SeqAIJKokkos*>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1013     }
1014     if (atb) {
1015       static_cast<MatProductData_SeqAIJKokkos*>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1016       static_cast<MatProductData_SeqAIJKokkos*>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1017     }
1018     PetscFunctionReturn(0);
1019   }
1020 
1021   if (ptype == MATPRODUCT_AB) {
1022     ab   = mmdata->mmAB;
1023     /* C1 = Ad * B_local */
1024     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AB");
1025     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local));
1026     PetscCheck(ab->C1->product->B == ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1027     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad,NULL,NULL,ab->C1));
1028     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1029     PetscCall(MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1030                                  ab->abuf,ab->rows,ab->rowoffset,ab->B_other));
1031     /* C2 = Ao * B_other */
1032     PetscCheck(ab->C2->product->B == ab->B_other,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1033     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao,NULL,NULL,ab->C2));
1034     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1035     /* C = C1_global + C2_global */
1036     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1037     mm = static_cast<MatMatStruct*>(ab);
1038   } else if (ptype == MATPRODUCT_AtB) {
1039     atb  = mmdata->mmAtB;
1040     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AtB");
1041     /* C1 = Ad^t * B_local */
1042     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&atb->B_local));
1043     PetscCheck(atb->C1->product->B == atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1044     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad,NULL,NULL,atb->C1));
1045     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1046 
1047     /* C2 = Ao^t * B_local */
1048     PetscCheck(atb->C2->product->B == atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1049     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao,NULL,NULL,atb->C2));
1050     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1051     /* Form C2_global */
1052     PetscCall(MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_TRUE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1053                                   atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global));
1054     /* C = C1_global + C2_global */
1055     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1056     mm = static_cast<MatMatStruct*>(atb);
1057   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1058     ab   = mmdata->mmAB;
1059     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local));
1060 
1061     /* ab->C1 = Ad * B_local */
1062     PetscCheck(ab->C1->product->B == ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1063     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad,NULL,NULL,ab->C1));
1064     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1065     PetscCall(MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1066                                  ab->abuf,ab->rows,ab->rowoffset,ab->B_other));
1067     /* ab->C2 = Ao * B_other */
1068     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao,NULL,NULL,ab->C2));
1069     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1070     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1071 
1072     /* atb->C1 = Bd^t * ab->C_petsc */
1073     atb  = mmdata->mmAtB;
1074     PetscCheck(atb->C1->product->B == ab->C_petsc,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1075     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd,NULL,NULL,atb->C1));
1076     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1077     /* atb->C2 = Bo^t * ab->C_petsc */
1078     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo,NULL,NULL,atb->C2));
1079     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1080     PetscCall(MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_FALSE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1081                                   atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global));
1082     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1083     mm = static_cast<MatMatStruct*>(atb);
1084   }
1085   /* Split C_global to form C */
1086   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_REUSE_MATRIX,mm->C_global,mm->Cdstart));
1087   PetscFunctionReturn(0);
1088 }
1089 
1090 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1091 {
1092   Mat                         A,B;
1093   Mat_Product                 *product = C->product;
1094   MatProductType              ptype;
1095   MatProductData_MPIAIJKokkos *mmdata;
1096   MatMatStruct                *mm = NULL;
1097   IS                          glob = NULL;
1098   const PetscInt              *garray;
1099   PetscInt                    m,n,M,N,sz;
1100   ConstMatColIdxKokkosView    l2g; /* map local col ids to global ones */
1101 
1102   PetscFunctionBegin;
1103   MatCheckProduct(C,1);
1104   PetscCheck(!product->data,PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Product data not empty");
1105   ptype = product->type;
1106   A     = product->A;
1107   B     = product->B;
1108 
1109   switch (ptype) {
1110     case MATPRODUCT_AB:   m = A->rmap->n; n = B->cmap->n; M = A->rmap->N; N = B->cmap->N; break;
1111     case MATPRODUCT_AtB:  m = A->cmap->n; n = B->cmap->n; M = A->cmap->N; N = B->cmap->N; break;
1112     case MATPRODUCT_PtAP: m = B->cmap->n; n = B->cmap->n; M = B->cmap->N; N = B->cmap->N; break; /* BtAB */
1113     default: SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[ptype]);
1114   }
1115 
1116   PetscCall(MatSetSizes(C,m,n,M,N));
1117   PetscCall(PetscLayoutSetUp(C->rmap));
1118   PetscCall(PetscLayoutSetUp(C->cmap));
1119   PetscCall(MatSetType(C,((PetscObject)A)->type_name));
1120 
1121   mmdata           = new MatProductData_MPIAIJKokkos();
1122   mmdata->reusesym = product->api_user;
1123 
1124   if (ptype == MATPRODUCT_AB) {
1125     mmdata->mmAB = new MatMatStruct_AB();
1126     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,mmdata->mmAB));
1127     mm   = static_cast<MatMatStruct*>(mmdata->mmAB);
1128   } else if (ptype == MATPRODUCT_AtB) {
1129     mmdata->mmAtB = new MatMatStruct_AtB();
1130     auto atb      = mmdata->mmAtB;
1131     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&atb->B_local));
1132     PetscCall(ISGetIndices(glob,&garray));
1133     PetscCall(ISGetSize(glob,&sz));
1134     l2g  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatColIdxKokkosViewHost(garray,sz));
1135     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product,A,atb->B_local,PETSC_TRUE,N,l2g,atb));
1136     PetscCall(ISRestoreIndices(glob,&garray));
1137     PetscCall(ISDestroy(&glob));
1138     mm   = static_cast<MatMatStruct*>(atb);
1139   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1140     mmdata->mmAB  = new MatMatStruct_AB(); /* tmp=A*B */
1141     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1142     auto ab       = mmdata->mmAB;
1143     auto atb      = mmdata->mmAtB;
1144     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,ab));
1145     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1146     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,tmp,&ab->C_petsc));
1147     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product,B,ab->C_petsc,PETSC_FALSE,N,l2g/*not used*/,atb));
1148     mm   = static_cast<MatMatStruct*>(atb);
1149   }
1150   /* Split the C_global into petsc A, B format */
1151   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_INITIAL_MATRIX,mm->C_global,mm->Cdstart));
1152   C->product->data        = mmdata;
1153   C->product->destroy     = MatProductDataDestroy_MPIAIJKokkos;
1154   C->ops->productnumeric  = MatProductNumeric_MPIAIJKokkos;
1155   PetscFunctionReturn(0);
1156 }
1157 
1158 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1159 {
1160   Mat_Product    *product = mat->product;
1161   PetscBool      match = PETSC_FALSE;
1162   PetscBool      usecpu = PETSC_FALSE;
1163 
1164   PetscFunctionBegin;
1165   MatCheckProduct(mat,1);
1166   if (!product->A->boundtocpu && !product->B->boundtocpu) {
1167     PetscCall(PetscObjectTypeCompare((PetscObject)product->B,((PetscObject)product->A)->type_name,&match));
1168   }
1169   if (match) { /* we can always fallback to the CPU if requested */
1170     switch (product->type) {
1171     case MATPRODUCT_AB:
1172       if (product->api_user) {
1173         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatMatMult","Mat");
1174         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL));
1175         PetscOptionsEnd();
1176       } else {
1177         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AB","Mat");
1178         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL));
1179         PetscOptionsEnd();
1180       }
1181       break;
1182     case MATPRODUCT_AtB:
1183       if (product->api_user) {
1184         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatTransposeMatMult","Mat");
1185         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL));
1186         PetscOptionsEnd();
1187       } else {
1188         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AtB","Mat");
1189         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL));
1190         PetscOptionsEnd();
1191       }
1192       break;
1193     case MATPRODUCT_PtAP:
1194       if (product->api_user) {
1195         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatPtAP","Mat");
1196         PetscCall(PetscOptionsBool("-matptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL));
1197         PetscOptionsEnd();
1198       } else {
1199         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_PtAP","Mat");
1200         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL));
1201         PetscOptionsEnd();
1202       }
1203       break;
1204     default:
1205       break;
1206     }
1207     match = (PetscBool)!usecpu;
1208   }
1209   if (match) {
1210     switch (product->type) {
1211     case MATPRODUCT_AB:
1212     case MATPRODUCT_AtB:
1213     case MATPRODUCT_PtAP:
1214       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1215       break;
1216     default:
1217       break;
1218     }
1219   }
1220   /* fallback to MPIAIJ ops */
1221   if (!mat->ops->productsymbolic) {
1222     PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1223   }
1224   PetscFunctionReturn(0);
1225 }
1226 
1227 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, const PetscInt coo_i[], const PetscInt coo_j[])
1228 {
1229   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ*)mat->data;
1230   Mat_MPIAIJKokkos *mpikok;
1231 
1232   PetscFunctionBegin;
1233   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat,coo_n,coo_i,coo_j));
1234   mat->preallocated = PETSC_TRUE;
1235   PetscCall(MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY));
1236   PetscCall(MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY));
1237   PetscCall(MatZeroEntries(mat));
1238   mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
1239   delete mpikok;
1240   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1241   PetscFunctionReturn(0);
1242 }
1243 
1244 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat,const PetscScalar v[],InsertMode imode)
1245 {
1246   Mat_MPIAIJ                     *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
1247   Mat_MPIAIJKokkos               *mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
1248   Mat                            A = mpiaij->A,B = mpiaij->B;
1249   PetscCount                     Annz = mpiaij->Annz,Annz2 = mpiaij->Annz2,Bnnz = mpiaij->Bnnz,Bnnz2 = mpiaij->Bnnz2;
1250   MatScalarKokkosView            Aa,Ba;
1251   MatScalarKokkosView            v1;
1252   MatScalarKokkosView&           vsend = mpikok->sendbuf_d;
1253   const MatScalarKokkosView&     v2 = mpikok->recvbuf_d;
1254   const PetscCountKokkosView&    Ajmap1 = mpikok->Ajmap1_d,Ajmap2 = mpikok->Ajmap2_d,Aimap2 = mpikok->Aimap2_d;
1255   const PetscCountKokkosView&    Bjmap1 = mpikok->Bjmap1_d,Bjmap2 = mpikok->Bjmap2_d,Bimap2 = mpikok->Bimap2_d;
1256   const PetscCountKokkosView&    Aperm1 = mpikok->Aperm1_d,Aperm2 = mpikok->Aperm2_d,Bperm1 = mpikok->Bperm1_d,Bperm2 = mpikok->Bperm2_d;
1257   const PetscCountKokkosView&    Cperm1 = mpikok->Cperm1_d;
1258   PetscMemType                   memtype;
1259 
1260   PetscFunctionBegin;
1261   PetscCall(PetscGetMemType(v,&memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1262   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */
1263     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),MatScalarKokkosViewHost((PetscScalar*)v,mpiaij->coo_n));
1264   } else {
1265     v1 = MatScalarKokkosView((PetscScalar*)v,mpiaij->coo_n); /* Directly use v[]'s memory */
1266   }
1267 
1268   if (imode == INSERT_VALUES) {
1269     PetscCall(MatSeqAIJGetKokkosViewWrite(A,&Aa)); /* write matrix values */
1270     PetscCall(MatSeqAIJGetKokkosViewWrite(B,&Ba));
1271   } else {
1272     PetscCall(MatSeqAIJGetKokkosView(A,&Aa)); /* read & write matrix values */
1273     PetscCall(MatSeqAIJGetKokkosView(B,&Ba));
1274   }
1275 
1276   /* Pack entries to be sent to remote */
1277   Kokkos::parallel_for(vsend.extent(0),KOKKOS_LAMBDA(const PetscCount i) {vsend(i) = v1(Cperm1(i));});
1278 
1279   /* Send remote entries to their owner and overlap the communication with local computation */
1280   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf,MPIU_SCALAR,PETSC_MEMTYPE_KOKKOS,vsend.data(),PETSC_MEMTYPE_KOKKOS,v2.data(),MPI_REPLACE));
1281   /* Add local entries to A and B in one kernel */
1282   Kokkos::parallel_for(Annz+Bnnz,KOKKOS_LAMBDA(PetscCount i) {
1283     PetscScalar sum = 0.0;
1284     if (i<Annz) {
1285       for (PetscCount k=Ajmap1(i); k<Ajmap1(i+1); k++) sum += v1(Aperm1(k));
1286       Aa(i) = (imode == INSERT_VALUES? 0.0 : Aa(i)) + sum;
1287     } else {
1288       i -= Annz;
1289       for (PetscCount k=Bjmap1(i); k<Bjmap1(i+1); k++) sum += v1(Bperm1(k));
1290       Ba(i) = (imode == INSERT_VALUES? 0.0 : Ba(i)) + sum;
1291     }
1292   });
1293   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf,MPIU_SCALAR,vsend.data(),v2.data(),MPI_REPLACE));
1294 
1295   /* Add received remote entries to A and B in one kernel */
1296   Kokkos::parallel_for(Annz2+Bnnz2,KOKKOS_LAMBDA(PetscCount i) {
1297     if (i < Annz2) {
1298       for (PetscCount k=Ajmap2(i); k<Ajmap2(i+1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1299     } else {
1300       i -= Annz2;
1301       for (PetscCount k=Bjmap2(i); k<Bjmap2(i+1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1302     }
1303   });
1304 
1305   if (imode == INSERT_VALUES) {
1306     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A,&Aa)); /* Increase A & B's state etc. */
1307     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B,&Ba));
1308   } else {
1309     PetscCall(MatSeqAIJRestoreKokkosView(A,&Aa));
1310     PetscCall(MatSeqAIJRestoreKokkosView(B,&Ba));
1311   }
1312   PetscFunctionReturn(0);
1313 }
1314 
1315 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1316 {
1317   Mat_MPIAIJ         *mpiaij = (Mat_MPIAIJ*)A->data;
1318 
1319   PetscFunctionBegin;
1320   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",NULL));
1321   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",NULL));
1322   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatSetPreallocationCOO_C",   NULL));
1323   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatSetValuesCOO_C",          NULL));
1324   delete (Mat_MPIAIJKokkos*)mpiaij->spptr;
1325   PetscCall(MatDestroy_MPIAIJ(A));
1326   PetscFunctionReturn(0);
1327 }
1328 
1329 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat* newmat)
1330 {
1331   Mat                B;
1332   Mat_MPIAIJ         *a;
1333 
1334   PetscFunctionBegin;
1335   if (reuse == MAT_INITIAL_MATRIX) {
1336     PetscCall(MatDuplicate(A,MAT_COPY_VALUES,newmat));
1337   } else if (reuse == MAT_REUSE_MATRIX) {
1338     PetscCall(MatCopy(A,*newmat,SAME_NONZERO_PATTERN));
1339   }
1340   B = *newmat;
1341 
1342   B->boundtocpu = PETSC_FALSE;
1343   PetscCall(PetscFree(B->defaultvectype));
1344   PetscCall(PetscStrallocpy(VECKOKKOS,&B->defaultvectype));
1345   PetscCall(PetscObjectChangeTypeName((PetscObject)B,MATMPIAIJKOKKOS));
1346 
1347   a = static_cast<Mat_MPIAIJ*>(A->data);
1348   if (a->A) PetscCall(MatSetType(a->A,MATSEQAIJKOKKOS));
1349   if (a->B) PetscCall(MatSetType(a->B,MATSEQAIJKOKKOS));
1350   if (a->lvec) PetscCall(VecSetType(a->lvec,VECSEQKOKKOS));
1351 
1352   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1353   B->ops->mult                  = MatMult_MPIAIJKokkos;
1354   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1355   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1356   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1357   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1358 
1359   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJSetPreallocation_C",MatMPIAIJSetPreallocation_MPIAIJKokkos));
1360   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJGetLocalMatMerge_C",MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1361   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatSetPreallocationCOO_C",   MatSetPreallocationCOO_MPIAIJKokkos));
1362   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatSetValuesCOO_C",          MatSetValuesCOO_MPIAIJKokkos));
1363   PetscFunctionReturn(0);
1364 }
1365 /*MC
   MATMPIAIJKOKKOS - MATAIJKOKKOS = "(mpi)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos

   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
1369 
1370    Options Database Keys:
1371 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1372 
1373   Level: beginner
1374 
1375 .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`
1376 M*/
1377 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1378 {
1379   PetscFunctionBegin;
1380   PetscCall(PetscKokkosInitializeCheck());
1381   PetscCall(MatCreate_MPIAIJ(A));
1382   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&A));
1383   PetscFunctionReturn(0);
1384 }
1385 
1386 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
1390    assembly performance the user should preallocate the matrix storage by setting
1391    the parameter nz (or the array nnz).  By setting these parameters accurately,
1392    performance during matrix assembly can be increased by more than a factor of 50.
1393 
1394    Collective
1395 
1396    Input Parameters:
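+  comm - MPI communicator
.  m - number of local rows
.  n - number of local columns
.  M - number of global rows
.  N - number of global columns
.  d_nz - number of nonzeros per row in the diagonal block of the local rows (same for all local rows)
.  d_nnz - array containing the number of nonzeros per row in the diagonal block
         (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal block of the local rows (same for all local rows);
         not used when comm has a single process
-  o_nnz - array containing the number of nonzeros per row in the off-diagonal block
         (possibly different for each row) or NULL; not used when comm has a single process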
1403 
1404    Output Parameter:
1405 .  A - the matrix
1406 
1407    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
1408    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1409    [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]
1410 
1411    Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored
1413 
1414    The AIJ format (also called the Yale sparse matrix format or
1415    compressed row storage), is fully compatible with standard Fortran 77
1416    storage.  That is, the stored row and column indices can begin at
1417    either one (as in Fortran) or zero.  See the users' manual for details.
1418 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz=PETSC_DEFAULT, o_nz=PETSC_DEFAULT, d_nnz=NULL and o_nnz=NULL for PETSc to control dynamic memory
1421    allocation.  For large problems you MUST preallocate memory or you
1422    will get TERRIBLE performance, see the users' manual chapter on matrices.
1423 
1424    By default, this format uses inodes (identical nodes) when possible, to
1425    improve numerical efficiency of matrix-vector products and solves. We
1426    search for consecutive rows with the same nonzero structure, thereby
1427    reusing matrix information to achieve increased efficiency.
1428 
1429    Level: intermediate
1430 
.seealso: `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
1432 @*/
1433 PetscErrorCode  MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat *A)
1434 {
1435   PetscMPIInt    size;
1436 
1437   PetscFunctionBegin;
1438   PetscCall(MatCreate(comm,A));
1439   PetscCall(MatSetSizes(*A,m,n,M,N));
1440   PetscCallMPI(MPI_Comm_size(comm,&size));
1441   if (size > 1) {
1442     PetscCall(MatSetType(*A,MATMPIAIJKOKKOS));
1443     PetscCall(MatMPIAIJSetPreallocation(*A,d_nz,d_nnz,o_nz,o_nnz));
1444   } else {
1445     PetscCall(MatSetType(*A,MATSEQAIJKOKKOS));
1446     PetscCall(MatSeqAIJSetPreallocation(*A,d_nz,d_nnz));
1447   }
1448   PetscFunctionReturn(0);
1449 }
1450 
1451 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1452 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1453 {
1454   PetscMPIInt                size,rank;
1455   MPI_Comm                   comm;
1456   PetscSplitCSRDataStructure d_mat=NULL;
1457 
1458   PetscFunctionBegin;
1459   PetscCall(PetscObjectGetComm((PetscObject)A,&comm));
1460   PetscCallMPI(MPI_Comm_size(comm,&size));
1461   PetscCallMPI(MPI_Comm_rank(comm,&rank));
1462   if (size == 1) {
1463     PetscCall(MatSeqAIJKokkosGetDeviceMat(A,&d_mat));
1464     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1465   } else {
1466     Mat_MPIAIJ  *aij = (Mat_MPIAIJ*)A->data;
1467     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A,&d_mat));
1468     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1469     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1470     PetscCheck(A->nooffprocentries || aij->donotstash,PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1471   }
1472   // act like MatSetValues because not called on host
1473   if (A->assembled) {
1474     if (A->was_assembled) {
1475       PetscCall(PetscInfo(A,"Assemble more than once already\n"));
1476     }
1477     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1478   } else {
1479     PetscCall(PetscInfo(A,"Warning !assemble ??? assembled=%" PetscInt_FMT "\n",A->assembled));
1480   }
1481   if (!d_mat) {
1482     struct _n_SplitCSRMat h_mat; /* host container */
1483     Mat_SeqAIJKokkos      *aijkokA;
1484     Mat_SeqAIJ            *jaca;
1485     PetscInt              n = A->rmap->n, nnz;
1486     Mat                   Amat;
1487     PetscInt              *colmap;
1488 
1489     /* create and copy h_mat */
1490     h_mat.M = A->cmap->N; // use for debug build
1491     PetscCall(PetscInfo(A,"Create device matrix in Kokkos\n"));
1492     if (size == 1) {
1493       Amat = A;
1494       jaca = (Mat_SeqAIJ*)A->data;
1495       h_mat.rstart = 0; h_mat.rend = A->rmap->n;
1496       h_mat.cstart = 0; h_mat.cend = A->cmap->n;
1497       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1498       h_mat.offdiag.a = NULL;
1499       aijkokA = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1500     } else {
1501       Mat_MPIAIJ       *aij = (Mat_MPIAIJ*)A->data;
1502       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ*)aij->B->data;
1503       PetscInt         ii;
1504       Mat_SeqAIJKokkos *aijkokB;
1505 
1506       Amat = aij->A;
1507       aijkokA = static_cast<Mat_SeqAIJKokkos*>(aij->A->spptr);
1508       aijkokB = static_cast<Mat_SeqAIJKokkos*>(aij->B->spptr);
1509       jaca = (Mat_SeqAIJ*)aij->A->data;
1510       PetscCheck(!aij->B->cmap->n || aij->garray,comm,PETSC_ERR_PLIB,"MPIAIJ Matrix was assembled but is missing garray");
1511       PetscCheck(aij->B->rmap->n == aij->A->rmap->n,comm,PETSC_ERR_SUP,"Only support aij->B->rmap->n == aij->A->rmap->n");
1512       aij->donotstash = PETSC_TRUE;
1513       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1514       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1515       PetscCall(PetscCalloc1(A->cmap->N,&colmap));
1516       PetscCall(PetscLogObjectMemory((PetscObject)A,(A->cmap->N)*sizeof(PetscInt)));
1517       for (ii=0; ii<aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii+1;
1518       // allocate B copy data
1519       h_mat.rstart = A->rmap->rstart; h_mat.rend = A->rmap->rend;
1520       h_mat.cstart = A->cmap->rstart; h_mat.cend = A->cmap->rend;
1521       nnz = jacb->i[n];
1522       if (jacb->compressedrow.use) {
1523         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_i_k (jacb->i,n+1);
1524         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_i_k));
1525         Kokkos::deep_copy (aijkokB->i_uncompressed_d, h_i_k);
1526         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1527       } else {
1528          h_mat.offdiag.i = aijkokB->i_device_data();
1529       }
1530       h_mat.offdiag.j = aijkokB->j_device_data();
1531       h_mat.offdiag.a = aijkokB->a_device_data();
1532       {
1533         Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_colmap_k (colmap,A->cmap->N);
1534         aijkokB->colmap_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_colmap_k));
1535         Kokkos::deep_copy (aijkokB->colmap_d, h_colmap_k);
1536         h_mat.colmap = aijkokB->colmap_d.data();
1537         PetscCall(PetscFree(colmap));
1538       }
1539       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1540       h_mat.offdiag.n = n;
1541     }
1542     // allocate A copy data
1543     nnz = jaca->i[n];
1544     h_mat.diag.n = n;
1545     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1546     PetscCallMPI(MPI_Comm_rank(comm,&h_mat.rank));
1547     PetscCheck(!jaca->compressedrow.use,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A does not suppport compressed row (todo)");
1548     else {
1549       h_mat.diag.i = aijkokA->i_device_data();
1550     }
1551     h_mat.diag.j = aijkokA->j_device_data();
1552     h_mat.diag.a = aijkokA->a_device_data();
1553     // copy pointers and metdata to device
1554     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat,&h_mat));
1555     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat,&d_mat));
1556     PetscCall(PetscInfo(A,"Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n",h_mat.diag.n, nnz));
1557   }
1558   *B = d_mat; // return it, set it in Mat, and set it up
1559   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1560   PetscFunctionReturn(0);
1561 }
1562 
1563 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1564 {
1565   Mat_SeqAIJKokkos  *aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1566 
1567   PetscFunctionBegin;
1568   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1569   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1570   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1571   else *mask = "PETSC_OFFLOAD_BOTH";
1572   PetscFunctionReturn(0);
1573 }
1574 
1575 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1576 {
1577   PetscMPIInt  size;
1578   Mat          Ad,Ao;
1579   const char  *amask,*bmask;
1580 
1581   PetscFunctionBegin;
1582   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A),&size));
1583 
1584   if (size == 1) {
1585     PetscCall(MatSeqAIJKokkosGetOffloadMask(A,&amask));
1586     PetscCall(PetscPrintf(PETSC_COMM_SELF,"%s\n",amask));
1587   } else {
1588     Ad  = ((Mat_MPIAIJ*)A->data)->A;
1589     Ao  = ((Mat_MPIAIJ*)A->data)->B;
1590     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad,&amask));
1591     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao,&bmask));
1592     PetscCall(PetscPrintf(PETSC_COMM_SELF,"Diag : Off-diag = %s : %s\n",amask,bmask));
1593   }
1594   PetscFunctionReturn(0);
1595 }
1596