xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision b45e3bf4ff73d80a20c3202b6cd9f79d2f2d3efe)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)
9 {
10   PetscErrorCode   ierr;
11   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ*)A->data;
12   Mat_SeqAIJKokkos *aijkok = mpiaij->A->spptr ? static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr) : NULL;
13 
14   PetscFunctionBegin;
15   ierr = MatAssemblyEnd_MPIAIJ(A,mode);CHKERRQ(ierr);
16   if (aijkok && aijkok->device_mat_d.data()) {
17     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
18   }
19 
20   PetscFunctionReturn(0);
21 }
22 
23 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])
24 {
25   PetscErrorCode ierr;
26   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
27 
28   PetscFunctionBegin;
29   ierr = PetscLayoutSetUp(mat->rmap);CHKERRQ(ierr);
30   ierr = PetscLayoutSetUp(mat->cmap);CHKERRQ(ierr);
31 #if defined(PETSC_USE_DEBUG)
32   if (d_nnz) {
33     PetscInt i;
34     for (i=0; i<mat->rmap->n; i++) {
35       PetscCheckFalse(d_nnz[i] < 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,d_nnz[i]);
36     }
37   }
38   if (o_nnz) {
39     PetscInt i;
40     for (i=0; i<mat->rmap->n; i++) {
41       PetscCheckFalse(o_nnz[i] < 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,o_nnz[i]);
42     }
43   }
44 #endif
45 #if defined(PETSC_USE_CTABLE)
46   ierr = PetscTableDestroy(&mpiaij->colmap);CHKERRQ(ierr);
47 #else
48   ierr = PetscFree(mpiaij->colmap);CHKERRQ(ierr);
49 #endif
50   ierr = PetscFree(mpiaij->garray);CHKERRQ(ierr);
51   ierr = VecDestroy(&mpiaij->lvec);CHKERRQ(ierr);
52   ierr = VecScatterDestroy(&mpiaij->Mvctx);CHKERRQ(ierr);
53   /* Because the B will have been resized we simply destroy it and create a new one each time */
54   ierr = MatDestroy(&mpiaij->B);CHKERRQ(ierr);
55 
56   if (!mpiaij->A) {
57     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->A);CHKERRQ(ierr);
58     ierr = MatSetSizes(mpiaij->A,mat->rmap->n,mat->cmap->n,mat->rmap->n,mat->cmap->n);CHKERRQ(ierr);
59     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->A);CHKERRQ(ierr);
60   }
61   if (!mpiaij->B) {
62     PetscMPIInt size;
63     ierr = MPI_Comm_size(PetscObjectComm((PetscObject)mat),&size);CHKERRMPI(ierr);
64     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->B);CHKERRQ(ierr);
65     ierr = MatSetSizes(mpiaij->B,mat->rmap->n,size > 1 ? mat->cmap->N : 0,mat->rmap->n,size > 1 ? mat->cmap->N : 0);CHKERRQ(ierr);
66     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->B);CHKERRQ(ierr);
67   }
68   ierr = MatSetType(mpiaij->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
69   ierr = MatSetType(mpiaij->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);
70   ierr = MatSeqAIJSetPreallocation(mpiaij->A,d_nz,d_nnz);CHKERRQ(ierr);
71   ierr = MatSeqAIJSetPreallocation(mpiaij->B,o_nz,o_nnz);CHKERRQ(ierr);
72   mat->preallocated = PETSC_TRUE;
73   PetscFunctionReturn(0);
74 }
75 
76 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
77 {
78   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
79   PetscErrorCode ierr;
80   PetscInt       nt;
81 
82   PetscFunctionBegin;
83   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
84   PetscCheckFalse(nt != mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
85   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
86   ierr = (*mpiaij->A->ops->mult)(mpiaij->A,xx,yy);CHKERRQ(ierr);
87   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
88   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,yy,yy);CHKERRQ(ierr);
89   PetscFunctionReturn(0);
90 }
91 
92 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)
93 {
94   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
95   PetscErrorCode ierr;
96   PetscInt       nt;
97 
98   PetscFunctionBegin;
99   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
100   PetscCheckFalse(nt != mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
101   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
102   ierr = (*mpiaij->A->ops->multadd)(mpiaij->A,xx,yy,zz);CHKERRQ(ierr);
103   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
104   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,zz,zz);CHKERRQ(ierr);
105   PetscFunctionReturn(0);
106 }
107 
108 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
109 {
110   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
111   PetscErrorCode ierr;
112   PetscInt       nt;
113 
114   PetscFunctionBegin;
115   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
116   PetscCheckFalse(nt != mat->rmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->rmap->n,nt);
117   ierr = (*mpiaij->B->ops->multtranspose)(mpiaij->B,xx,mpiaij->lvec);CHKERRQ(ierr);
118   ierr = (*mpiaij->A->ops->multtranspose)(mpiaij->A,xx,yy);CHKERRQ(ierr);
119   ierr = VecScatterBegin(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
120   ierr = VecScatterEnd(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
121   PetscFunctionReturn(0);
122 }
123 
124 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
125    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
126    C still uses local column ids. Their corresponding global column ids are returned in glob.
127 */
128 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS *glob,Mat *C)
129 {
130   Mat            Ad,Ao;
131   const PetscInt *cmap;
132   PetscErrorCode ierr;
133 
134   PetscFunctionBegin;
135   ierr = MatMPIAIJGetSeqAIJ(mat,&Ad,&Ao,&cmap);CHKERRQ(ierr);
136   ierr = MatSeqAIJKokkosMergeMats(Ad,Ao,reuse,C);CHKERRQ(ierr);
137   if (glob) {
138     PetscInt cst, i, dn, on, *gidx;
139     ierr = MatGetLocalSize(Ad,NULL,&dn);CHKERRQ(ierr);
140     ierr = MatGetLocalSize(Ao,NULL,&on);CHKERRQ(ierr);
141     ierr = MatGetOwnershipRangeColumn(mat,&cst,NULL);CHKERRQ(ierr);
142     ierr = PetscMalloc1(dn+on,&gidx);CHKERRQ(ierr);
143     for (i=0; i<dn; i++) gidx[i]    = cst + i;
144     for (i=0; i<on; i++) gidx[i+dn] = cmap[i];
145     ierr = ISCreateGeneral(PetscObjectComm((PetscObject)Ad),dn+on,gidx,PETSC_OWN_POINTER,glob);CHKERRQ(ierr);
146   }
147   PetscFunctionReturn(0);
148 }
149 
150 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
151 struct MatMatStruct {
152   MatRowMapKokkosView   Cdstart; /* Used to split sequential matrix into petsc's A, B format */
153   PetscSF               sf; /* SF to send/recv matrix entries */
154   MatScalarKokkosView   abuf; /* buf of mat values in send/recv */
155   Mat                   C1,C2,B_local;
156   KokkosCsrMatrix       C1_global,C2_global,C_global;
157   KernelHandle          kh;
158   MatMatStruct() {
159     C1 = C2 = B_local = NULL;
160     sf = NULL;
161   }
162 
163   ~MatMatStruct() {
164     MatDestroy(&C1);
165     MatDestroy(&C2);
166     MatDestroy(&B_local);
167     PetscSFDestroy(&sf);
168     kh.destroy_spadd_handle();
169   }
170 };
171 
172 struct MatMatStruct_AB : public MatMatStruct {
173   MatColIdxKokkosView   rows;
174   MatRowMapKokkosView   rowoffset;
175   Mat                   B_other,C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
176 
177   MatMatStruct_AB() : B_other(NULL),C_petsc(NULL){}
178   ~MatMatStruct_AB() {
179     MatDestroy(&B_other);
180     MatDestroy(&C_petsc);
181   }
182 };
183 
184 struct MatMatStruct_AtB : public MatMatStruct {
185   MatRowMapKokkosView   srcrowoffset,dstrowoffset;
186 };
187 
188 struct MatProductData_MPIAIJKokkos
189 {
190   MatMatStruct_AB   *mmAB;
191   MatMatStruct_AtB  *mmAtB;
192   PetscBool         reusesym;
193 
194   MatProductData_MPIAIJKokkos(): mmAB(NULL),mmAtB(NULL),reusesym(PETSC_FALSE){}
195   ~MatProductData_MPIAIJKokkos() {
196     delete mmAB;
197     delete mmAtB;
198   }
199 };
200 
201 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202 {
203   PetscFunctionBegin;
204   CHKERRCXX(delete static_cast<MatProductData_MPIAIJKokkos*>(data));
205   PetscFunctionReturn(0);
206 }
207 
208 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
209 
210    Input Parameters:
211 +  A       - the MATSEQAIJKOKKOS matrix
212 .  N       - new column size for the returned Kokkos matrix
213 -  l2g     - a map that maps old col ids to new col ids
214 
215    Output Parameters:
216 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
217  */
218 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A,PetscInt N,const ConstMatColIdxKokkosView& l2g,KokkosCsrMatrix& csrmat)
219 {
220   KokkosCsrMatrix&         orig = static_cast<Mat_SeqAIJKokkos*>(A->spptr)->csrmat;
221   MatColIdxKokkosView      jg("jg",orig.nnz()); /* New j array for csrmat */
222 
223   PetscFunctionBegin;
224   CHKERRCXX(Kokkos::parallel_for(orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) {jg(i) = l2g(orig.graph.entries(i));}));
225   CHKERRCXX(csrmat = KokkosCsrMatrix("csrmat",orig.numRows(),N,orig.nnz(),orig.values,orig.graph.row_map,jg));
226   PetscFunctionReturn(0);
227 }
228 
229 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
230    It is similar to MatCreateMPIAIJWithSplitArrays.
231 
232   Input Parameters:
233 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
234 .  A     - the diag matrix using local col ids
235 -  B     - the offdiag matrix using global col ids
236 
237   Output Parameters:
238 .  mat   - the updated MATMPIAIJKOKKOS matrix
239 */
240 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat,Mat A,Mat B)
241 {
242   PetscErrorCode      ierr;
243   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
244   PetscInt            m,n,M,N,Am,An,Bm,Bn;
245   Mat_SeqAIJKokkos    *bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
246 
247   PetscFunctionBegin;
248   ierr = MatGetSize(mat,&M,&N);CHKERRQ(ierr);
249   ierr = MatGetLocalSize(mat,&m,&n);CHKERRQ(ierr);
250   ierr = MatGetLocalSize(A,&Am,&An);CHKERRQ(ierr);
251   ierr = MatGetLocalSize(B,&Bm,&Bn);CHKERRQ(ierr);
252 
253   PetscCheckFalse(m != Am || m != Bm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of rows do not match");
254   PetscCheckFalse(n != An,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of columns do not match");
255   PetscCheckFalse(N != Bn,PETSC_COMM_SELF,PETSC_ERR_PLIB,"global number of columns do not match");
256   PetscCheckFalse(mpiaij->A || mpiaij->B,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A, B of the MPIAIJ matrix are not empty");
257   mpiaij->A = A;
258   mpiaij->B = B;
259 
260   mat->preallocated     = PETSC_TRUE;
261   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
262 
263   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE);CHKERRQ(ierr);
264   ierr = MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
265   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
266     also gets mpiaij->B compacted, with its col ids and size reduced
267   */
268   ierr = MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
269   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_FALSE);CHKERRQ(ierr);
270   ierr = MatSetOption(mat,MAT_NEW_NONZERO_LOCATION_ERR,PETSC_TRUE);CHKERRQ(ierr);
271 
272   /* Update bkok with new local col ids (stored on host) and size */
273   bkok->j_dual.modify_host();
274   bkok->j_dual.sync_device();
275   bkok->SetColSize(mpiaij->B->cmap->n);
276   PetscFunctionReturn(0);
277 }
278 
279 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
280 
281    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
282    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
283    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
284    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
285 
286    Collective on comm of ownerSF
287 
288    Input Parameters:
289 +   B       - the SEQAIJKOKKOS matrix, using local col ids
290 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
291 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
292 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
293 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
294 
295    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
296 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
297 .   abuf      - buffer for sending matrix values
298 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
299                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
300 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
301 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
302 */
303 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B,MatReuse reuse,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
304                                            PetscSF& bcastSF,MatScalarKokkosView& abuf,MatColIdxKokkosView& rows,
305                                            MatRowMapKokkosView& rowoffset,Mat& C)
306 {
307   PetscErrorCode               ierr;
308   Mat_SeqAIJKokkos             *bkok,*ckok;
309 
310   PetscFunctionBegin;
311   ierr = MatSeqAIJKokkosSyncDevice(B);CHKERRQ(ierr); /* Make sure B->spptr is accessible */
312   bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
313 
314   if (reuse == MAT_REUSE_MATRIX) {
315     ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
316 
317     const auto& Ba = bkok->a_dual.view_device();
318     const auto& Bi = bkok->i_dual.view_device();
319     const auto& Ca = ckok->a_dual.view_device();
320 
321     /* Copy Ba to abuf */
322     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
323       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
324       PetscInt r    = rows(i);
325       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
326       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
327         abuf(base+k) = Ba(Bi(r)+k);
328       });
329     });
330 
331     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
332     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr); /* TODO: get memtype for abuf */
333     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr);
334     ckok->a_dual.modify_device();
335   } else if (reuse == MAT_INITIAL_MATRIX) {
336     MPI_Comm       comm;
337     PetscMPIInt    tag;
338     PetscInt       k,Cm,Cn,Cnnz,*Ci_h,nroots,nleaves;
339 
340     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRMPI(ierr);
341     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
342     Cm   = nleaves; /* row size of C */
343     Cn   = N;  /* col size of C, which initially uses global ids, so we can safely set its col size as N */
344 
345     /* Get row lens (nz) of B's rows for later fast query */
346     PetscInt       *Browlens;
347     const PetscInt *tmp = bkok->i_host_data();
348     ierr = PetscMalloc1(nroots,&Browlens);CHKERRQ(ierr);
349     for (k=0; k<nroots; k++) Browlens[k] = tmp[k+1]-tmp[k];
350 
351     /* By ownerSF, each proc gets lens of rows of C */
352     MatRowMapKokkosDualView Ci("i",Cm+1); /* C's rowmap */
353     Ci_h    = Ci.view_host().data();
354     Ci_h[0] = 0;
355     ierr    = PetscSFBcastWithMemTypeBegin(ownerSF,MPIU_INT,PETSC_MEMTYPE_HOST,Browlens,PETSC_MEMTYPE_HOST,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
356     ierr    = PetscSFBcastEnd(ownerSF,MPIU_INT,Browlens,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
357     for (k=1; k<Cm+1; k++) Ci_h[k] += Ci_h[k-1]; /* Convert lens to CSR */
358     Cnnz    = Ci_h[Cm];
359     Ci.modify_host();
360     Ci.sync_device();
361 
362     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
363     MatColIdxKokkosDualView  Cj("j",Cnnz);
364     MatScalarKokkosDualView  Ca("a",Cnnz);
365 
366     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
367     const PetscMPIInt *iranks,*ranks;
368     const PetscInt    *ioffset,*irootloc,*roffset;
369     PetscInt          i,j,niranks,nranks,*sdisp,*rdisp,*rowptr;
370     MPI_Request       *reqs;
371 
372     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&irootloc);CHKERRQ(ierr); /* irootloc[] contains indices of rows I need to send to each receiver */
373     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* recv info */
374 
375     /* figure out offsets at the send buffer, to build the SF
376       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
377       rowptr[] - stores offsets for data of each row in abuf
378 
379       rdisp[]  - to receive sdisp[]
380     */
381     ierr = PetscMalloc3(niranks+1,&sdisp,nranks,&rdisp,niranks+nranks,&reqs);CHKERRQ(ierr);
382     MatRowMapKokkosViewHost rowptr_h("rowptr_h",ioffset[niranks]+1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
383     rowptr = rowptr_h.data();
384 
385     sdisp[0] = 0;
386     rowptr[0]  = 0;
387     for (i=0; i<niranks; i++) { /* for each receiver */
388       PetscInt len, nz = 0;
389       for (j=ioffset[i]; j<ioffset[i+1]; j++) { /* for each row to this receiver */
390         len         = Browlens[irootloc[j]];
391         rowptr[j+1] = rowptr[j] + len;
392         nz         += len;
393       }
394       sdisp[i+1] = sdisp[i] + nz;
395     }
396     ierr = PetscCommGetNewTag(comm,&tag);CHKERRMPI(ierr);
397     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
398     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
399     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
400 
401     PetscInt    nleaves2 = Cnnz; /* leaves are the nonzeros I will receive */
402     PetscInt    nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
403     PetscSFNode *iremote;
404     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr);
405     for (i=0; i<nranks; i++) { /* for each sender */
406       k = 0;
407       for (j=Ci_h[roffset[i]]; j<Ci_h[roffset[i+1]]; j++) {
408         iremote[j].rank  = ranks[i];
409         iremote[j].index = rdisp[i] + k;
410         k++;
411       }
412     }
413     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
414     ierr = PetscSFCreate(comm,&bcastSF);CHKERRQ(ierr);
415     ierr = PetscSFSetGraph(bcastSF,nroots2,nleaves2,NULL/*ilocal*/,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
416 
417     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
418       from local to global. Then use bcastSF to fill Ca, Cj.
419     */
420     ConstMatColIdxKokkosViewHost rows_h(irootloc,ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
421     MatColIdxKokkosView          rows("rows",ioffset[niranks]);
422     Kokkos::deep_copy(rows,rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
423 
424     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
425 
426     MatColIdxKokkosView jbuf("jbuf",sdisp[niranks]); /* send buf for (global) col ids */
427     abuf = MatScalarKokkosView("abuf",sdisp[niranks]); /* send buf for mat values */
428 
429     const auto& Ba = bkok->a_dual.view_device();
430     const auto& Bi = bkok->i_dual.view_device();
431     const auto& Bj = bkok->j_dual.view_device();
432 
433     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
434     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
435       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
436       PetscInt r    = rows(i);
437       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
438       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
439         abuf(base+k) = Ba(Bi(r)+k);
440         jbuf(base+k) = l2g(Bj(Bi(r)+k));
441       });
442     });
443 
444     /* Send abuf & jbuf to fill Ca, Cj */
445     ierr = PetscSFBcastBegin(bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
446     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
447     ierr = PetscSFBcastEnd  (bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
448     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
449     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
450     Cj.sync_host();
451     Ca.modify_device();
452 
453     /* Construct C with Ca, Ci, Cj */
454     auto ckok = new Mat_SeqAIJKokkos(Cm,Cn,Cnnz,Ci,Cj,Ca);
455     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,ckok,&C);CHKERRQ(ierr);
456     ierr = PetscFree3(sdisp,rdisp,reqs);CHKERRQ(ierr);
457     ierr = PetscFree(Browlens);CHKERRQ(ierr);
458   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
459   PetscFunctionReturn(0);
460 }
461 
462 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
463 
464   It is the reverse of MatSeqAIJKokkosBcast in some sense.
465 
466   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
467   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
468   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
469 
470   Input Parameters:
471 +  A        - the SEQAIJKOKKOS matrix to be reduced
472 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
473 .  local    - true if A uses local col ids; false if A is already in global col ids.
474 .  N        - if local, N is A's global col size
475 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
476 -  ownerSF  - the SF specifies ownership (root) of rows in A
477 
478   Output Parameters:
479 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
480 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
481 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
482 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
483                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
484 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
485 
486    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
487  */
488 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A,MatReuse reuse,PetscBool local,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
489                                             PetscSF& reduceSF,MatScalarKokkosView& abuf,
490                                             MatRowMapKokkosView& srcrowoffset,MatRowMapKokkosView& dstrowoffset,
491                                             KokkosCsrMatrix& C)
492 {
493   PetscErrorCode         ierr;
494   PetscInt               i,r,Am,An,Annz,Cnnz,nrows;
495   const PetscInt         *Ai;
496   Mat_SeqAIJKokkos       *akok;
497 
498   PetscFunctionBegin;
499   ierr = MatSeqAIJKokkosSyncDevice(A);CHKERRQ(ierr); /* So that A's latest data is on device */
500   ierr = MatGetSize(A,&Am,&An);
501   Ai   = static_cast<Mat_SeqAIJ*>(A->data)->i;
502   akok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
503   Annz = Ai[Am];
504 
505   if (reuse == MAT_REUSE_MATRIX) {
506     /* Send Aa to abuf */
507     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
508     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
509 
510     /* Copy abuf to Ca */
511     const MatScalarKokkosView& Ca = C.values;
512     nrows = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
513     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
514       PetscInt i   = t.league_rank();
515       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
516       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
517       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {Ca(dst+k) = abuf(src+k);});
518     });
519   } else if (reuse == MAT_INITIAL_MATRIX) {
520     MPI_Comm               comm;
521     MPI_Request            *reqs;
522     PetscMPIInt            tag;
523     PetscInt               Cm;
524 
525     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRQ(ierr);
526     ierr = PetscCommGetNewTag(comm,&tag);CHKERRQ(ierr);
527 
528     PetscInt niranks,nranks,nroots,nleaves;
529     const PetscMPIInt *iranks,*ranks;
530     const PetscInt *ioffset,*rows,*roffset;  /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
531     ierr = PetscSFSetUp(ownerSF);CHKERRQ(ierr);
532     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&rows);CHKERRQ(ierr); /* recv info: iranks[] will send rows to me */
533     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* send info */
534     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
535     PetscCheckFalse(nleaves != Am,PETSC_COMM_SELF,PETSC_ERR_PLIB,"ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")",nleaves,Am);
536     Cm    = nroots;
537     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
538 
539     /* Tell owners how long each row I will send */
540     PetscInt                *srowlens; /* send buf of row lens */
541     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h",nrows+1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
542     PetscInt                *rrowlens = rrowlens_h.data();
543 
544     ierr = PetscMalloc2(Am,&srowlens,niranks+nranks,&reqs);CHKERRQ(ierr);
545     for (i=0; i<Am; i++) srowlens[i] = Ai[i+1] - Ai[i];
546     rrowlens[0] = 0;
547     rrowlens++; /* shift the pointer to make the following expression more readable */
548     for (i=0; i<niranks; i++){ierr = MPI_Irecv(&rrowlens[ioffset[i]],ioffset[i+1]-ioffset[i],MPIU_INT,iranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
549     for (i=0; i<nranks; i++) {ierr = MPI_Isend(&srowlens[roffset[i]],roffset[i+1]-roffset[i],MPIU_INT,ranks[i],tag,comm,&reqs[niranks+i]);CHKERRMPI(ierr);}
550     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
551 
552     /* Owner builds Ci on host by histogramming rrowlens[] */
553     MatRowMapKokkosViewHost Ci_h("i",Cm+1);
554     Kokkos::deep_copy(Ci_h,0); /* Zero Ci */
555     MatRowMapType *Ci_ptr = Ci_h.data();
556 
557     for (i=0; i<nrows; i++) {
558       r = rows[i]; /* local row id of i-th received row */
559      #if defined(PETSC_USE_DEBUG)
560       PetscCheckFalse(r<0 || r>=Cm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")",r,Cm);
561      #endif
562       Ci_ptr[r+1] += rrowlens[i]; /* add to length of row r in C */
563     }
564     for (i=0; i<Cm; i++) Ci_ptr[i+1] += Ci_ptr[i]; /* to CSR format */
565     Cnnz = Ci_ptr[Cm];
566 
567     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
568     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h",nrows);
569     PetscInt                *dstrowoffset_hptr = dstrowoffset_h.data();
570     PetscInt                *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
571 
572     ierr = PetscCalloc1(Cm,&currowlens);CHKERRQ(ierr); /* Init with zero, to be added to */
573     for (i=0; i<nrows; i++) { /* for each row I receive */
574       r                    = rows[i]; /* row id in C */
575       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
576       currowlens[r]       += rrowlens[i]; /* accumulate to length of row r in C */
577     }
578     ierr = PetscFree(currowlens);CHKERRQ(ierr);
579 
580     rrowlens--;
581     for (i=0; i<nrows; i++) rrowlens[i+1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
582     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),dstrowoffset_h);
583     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
584 
585     /* Build the reduceSF, which performs buffer to buffer send/recv */
586     PetscInt *sdisp,*rdisp; /* buffer to send offsets of roots, and buffer to recv them */
587     ierr = PetscMalloc2(niranks,&sdisp,nranks,&rdisp);CHKERRQ(ierr);
588     for (i=0; i<niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
589     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
590     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
591     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
592 
593     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
594     PetscInt    nroots2 = Cnnz,nleaves2 = Annz;
595     PetscSFNode *iremote;
596     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr); /* no free, since memory will be given to reduceSF */
597     for (i=0; i<nranks; i++) {
598       PetscInt rootbase = rdisp[i]; /* root offset at this root rank */
599       PetscInt leafbase = Ai[roffset[i]]; /* leaf base */
600       PetscInt nz       = Ai[roffset[i+1]] - leafbase; /* I will send nz nonzeros to this root rank */
601       for (PetscInt k=0; k<nz; k++) {
602         iremote[leafbase+k].rank  = ranks[i];
603         iremote[leafbase+k].index = rootbase + k;
604       }
605     }
606     ierr = PetscSFCreate(comm,&reduceSF);CHKERRQ(ierr);
607     ierr = PetscSFSetGraph(reduceSF,nroots2,nleaves2,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
608     ierr = PetscFree2(sdisp,rdisp);CHKERRQ(ierr);
609 
610     /* Reduce Aa, Ajg to abuf and jbuf */
611 
612     /* If A uses local col ids, convert them to global ones before sending */
613     MatColIdxKokkosView Ajg;
614     if (local) {
615       Ajg = MatColIdxKokkosView("j",Annz);
616       const MatColIdxKokkosView& Aj = akok->j_dual.view_device();
617       Kokkos::parallel_for(Annz,KOKKOS_LAMBDA(const PetscInt i) {Ajg(i) = l2g(Aj(i));});
618     } else {
619       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
620     }
621 
622     MatColIdxKokkosView   jbuf("jbuf",Cnnz);
623     abuf = MatScalarKokkosView("abuf",Cnnz);
624     ierr = PetscSFReduceBegin(reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
625     ierr = PetscSFReduceEnd  (reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
626     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
627     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
628 
629     /* Copy data from abuf, jbuf to Ca, Cj */
630     MatRowMapKokkosView    Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),Ci_h); /* Ci is an alias of Ci_h if no device */
631     MatColIdxKokkosView    Cj("j",Cnnz);
632     MatScalarKokkosView    Ca("a",Cnnz);
633 
634     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
635       PetscInt i   = t.league_rank();
636       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
637       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
638       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {
639         Ca(dst+k) = abuf(src+k);
640         Cj(dst+k) = jbuf(src+k);
641       });
642     });
643 
644     /* Build C with Ca, Ci, Cj */
645     C    = KokkosCsrMatrix("csrmat",Cm,N,Cnnz,Ca,Ci,Cj);
646     ierr = PetscFree2(srowlens,reqs);CHKERRQ(ierr);
647   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
648   PetscFunctionReturn(0);
649 }
650 
651 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
652 
653   Input Parameters:
654 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
655 .  reuse    - indicate whether the matrix has called this function before
656 .  csrmat   - the KokkosCsrMatrix, of size m,N
657 -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
658               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
659               entry is 5, then Cdstart[i] = 3.
660 
661   Output Parameters:
662 +  C        - the updated MATMPIAIJKOKKOS matrix
663 -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
664 
665   Notes:
666    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
667  */
668 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C,MatReuse reuse,const KokkosCsrMatrix& csrmat,MatRowMapKokkosView& Cdstart)
669 {
670   PetscErrorCode                  ierr;
671   const MatScalarKokkosView&      Ca = csrmat.values;
672   const ConstMatRowMapKokkosView& Ci = csrmat.graph.row_map;
673   PetscInt                        m,n,N;
674 
675   PetscFunctionBegin;
676   ierr = MatGetLocalSize(C,&m,&n);CHKERRQ(ierr);
677   ierr = MatGetSize(C,NULL,&N);CHKERRQ(ierr);
678 
679   if (reuse == MAT_REUSE_MATRIX) {
680     Mat_MPIAIJ                  *mpiaij = static_cast<Mat_MPIAIJ*>(C->data);
681     Mat_SeqAIJKokkos            *akok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr);
682     Mat_SeqAIJKokkos            *bkok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->B->spptr);
683     const MatScalarKokkosView&  Cda = akok->a_dual.view_device(),Coa = bkok->a_dual.view_device();
684     const MatRowMapKokkosView&  Cdi = akok->i_dual.view_device(),Coi = bkok->i_dual.view_device();
685 
686     /* Fill 'a' of Cd and Co on device */
687     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
688       PetscInt i       = t.league_rank(); /* row i */
689       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
690       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
691       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
692       PetscInt cdend   = cdstart + cdlen;
693       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
694       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
695         if (k < cdstart) {  /* k in [0, cdstart) */
696           Coa(Coi(i)+k) = Ca(Ci(i)+k);
697         } else if (k < cdend) { /* k in [cdstart, cdend) */
698           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
699         } else { /* k in [cdend, clen) */
700           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
701         }
702       });
703     });
704 
705     akok->a_dual.modify_device();
706     bkok->a_dual.modify_device();
707   } else if (reuse == MAT_INITIAL_MATRIX) {
708     Mat                         Cd,Co;
709     const MatColIdxKokkosView&  Cj = csrmat.graph.entries;
710     MatRowMapKokkosDualView     Cdi_dual("i",m+1),Coi_dual("i",m+1);
711     MatRowMapKokkosView         Cdi = Cdi_dual.view_device(),Coi = Coi_dual.view_device();
712     PetscInt                    cstart,cend;
713 
714     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
715        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
716        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
717        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
718        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
719      */
720     Cdstart = MatRowMapKokkosView("Cdstart",m);
721     ierr    = PetscLayoutGetRange(C->cmap,&cstart,&cend);CHKERRQ(ierr); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
722 
723     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
724       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
725      */
726     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, 1),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
727       Kokkos::single(Kokkos::PerTeam(t), [=] () { /* Only one thread works in a team */
728         PetscInt i = t.league_rank(); /* row i */
729         PetscInt j,first,count,step;
730 
731         if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
732           Cdi(0) = 0;
733           Coi(0) = 0;
734         }
735 
736         /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
737           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
738         */
739         count = Ci(i+1)-Ci(i);
740         first = Ci(i);
741         while (count > 0) {
742           j    = first;
743           step = count / 2;
744           j   += step;
745           if (Cj(j) < cstart) {
746             first  = ++j;
747             count -= step + 1;
748           } else count = step;
749         }
750         Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
751 
752         /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
753         count = Ci(i+1) - first;
754         while (count > 0) {
755           j    = first;
756           step = count / 2;
757           j   += step;
758           if (Cj(j) < cend) {
759             first  = ++j;
760             count -= step + 1;
761           } else count = step;
762         }
763         Cdi(i+1) = first - (Ci(i)+Cdstart(i)); /* 'first' is the while-loop's output */
764         Coi(i+1) = (Ci(i+1)-Ci(i)) - Cdi(i+1); /* Co's row len = C's row len - Cd's row len */
765       });
766     });
767 
768     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
769     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
770       update += Cdi(i);
771       if (final) Cdi(i) = update;
772     });
773     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
774       update += Coi(i);
775       if (final) Coi(i) = update;
776     });
777 
778     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
779        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
780     */
781     Cdi_dual.modify_device();
782     Coi_dual.modify_device();
783     Cdi_dual.sync_host();
784     Coi_dual.sync_host();
785     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
786     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
787 
788     /* With nnz, allocate a, j for Cd and Co */
789     MatColIdxKokkosDualView Cdj_dual("j",Cd_nnz),Coj_dual("j",Co_nnz);
790     MatScalarKokkosDualView Cda_dual("a",Cd_nnz),Coa_dual("a",Co_nnz);
791 
792     /* Fill a, j of Cd and Co on device */
793     MatColIdxKokkosView     Cdj = Cdj_dual.view_device(),Coj = Coj_dual.view_device();
794     MatScalarKokkosView     Cda = Cda_dual.view_device(),Coa = Coa_dual.view_device();
795 
796     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
797       PetscInt i       = t.league_rank(); /* row i */
798       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
799       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
800       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
801       PetscInt cdend   = cdstart + cdlen;
802       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
803       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
804         if (k < cdstart) { /* k in [0, cdstart) */
805           Coa(Coi(i)+k) = Ca(Ci(i)+k);
806           Coj(Coi(i)+k) = Cj(Ci(i)+k);
807         } else if (k < cdend) { /* k in [cdstart, cdend) */
808           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
809           Cdj(Cdi(i)+(k-cdstart)) = Cj(Ci(i)+k) - cstart; /* Use local col ids in Cdj */
810         } else { /* k in [cdend, clen) */
811           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
812           Coj(Coi(i)+k-cdlen) = Cj(Ci(i)+k);
813         }
814       });
815     });
816 
817     Cdj_dual.modify_device();
818     Cda_dual.modify_device();
819     Coj_dual.modify_device();
820     Coa_dual.modify_device();
821     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
822     auto cdkok = new Mat_SeqAIJKokkos(m,n,Cd_nnz,Cdi_dual,Cdj_dual,Cda_dual);
823     auto cokok = new Mat_SeqAIJKokkos(m,N,Co_nnz,Coi_dual,Coj_dual,Coa_dual);
824     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cdkok,&Cd);CHKERRQ(ierr);
825     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cokok,&Co);CHKERRQ(ierr);
826     ierr = MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C,Cd,Co);CHKERRQ(ierr); /* Coj will be converted to local ids within */
827   }
828   PetscFunctionReturn(0);
829 }
830 
831 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
832 
833   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
834 
835   Input Parameters:
836 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
837 .  reuse    - indicate whether the matrix has called this function before
838 .  csrmat   - the KokkosCsrMatrix, of size m,N
839 -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
840               entry of the diag block of C in csrmat's j array.
841 
842   Output Parameters:
843 +  C        - the updated MATMPIAIJKOKKOS matrix
844 -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
845 
846   Notes: the input matrix's col ids and col size will be changed.
847 */
848 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C,MatColIdxKokkosView& l2g)
849 {
850   PetscErrorCode         ierr;
851   Mat_SeqAIJKokkos       *ckok;
852   ISLocalToGlobalMapping l2gmap;
853   const PetscInt         *garray;
854   PetscInt               sz;
855 
856   PetscFunctionBegin;
857   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
858   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJ(C,&l2gmap);CHKERRQ(ierr);
859   ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
860   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
861   ckok->j_dual.sync_device();
862   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
863 
864   /* Build l2g -- the local to global mapping of C's cols */
865   ierr = ISLocalToGlobalMappingGetIndices(l2gmap,&garray);CHKERRQ(ierr);
866   ierr = ISLocalToGlobalMappingGetSize(l2gmap,&sz);CHKERRQ(ierr);
867   PetscCheckFalse(C->cmap->n != sz,PETSC_COMM_SELF,PETSC_ERR_PLIB,"matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n,sz);
868 
869   ConstMatColIdxKokkosViewHost tmp(garray,sz);
870   l2g = MatColIdxKokkosView("l2g",sz);
871   Kokkos::deep_copy(l2g,tmp);
872 
873   ierr = ISLocalToGlobalMappingRestoreIndices(l2gmap,&garray);CHKERRQ(ierr);
874   ierr = ISLocalToGlobalMappingDestroy(&l2gmap);CHKERRQ(ierr);
875   PetscFunctionReturn(0);
876 }
877 
878 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
879 
880   Input Parameters:
881 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
882 .  A        - an MPIAIJKOKKOS matrix
883 .  B        - an MPIAIJKOKKOS matrix
884 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
885 
886   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
887 */
888 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product,Mat A,Mat B,MatMatStruct_AB *mm)
889 {
890   PetscErrorCode              ierr;
891   Mat_MPIAIJ                  *a = static_cast<Mat_MPIAIJ*>(A->data);
892   Mat                         Ad = a->A,Ao = a->B; /* diag and offdiag of A */
893   IS                          glob = NULL;
894   const PetscInt              *garray;
895   PetscInt                    N = B->cmap->N,sz;
896   ConstMatColIdxKokkosView    l2g1; /* two temp maps mapping local col ids to global ones */
897   MatColIdxKokkosView         l2g2;
898   Mat                         C1,C2; /* intermediate matrices */
899 
900   PetscFunctionBegin;
901   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
902   ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&mm->B_local);CHKERRQ(ierr);
903   ierr = MatProductCreate(Ad,mm->B_local,NULL,&C1);CHKERRQ(ierr);
904   ierr = MatProductSetType(C1,MATPRODUCT_AB);CHKERRQ(ierr);
905   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
906   C1->product->api_user = product->api_user;
907   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
908   PetscCheckFalse(!C1->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
909   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
910 
911   ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
912   ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
913   const auto& tmp  = ConstMatColIdxKokkosViewHost(garray,sz); /* wrap garray as a view */
914   l2g1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
915   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g1,mm->C1_global);
916 
917   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
918   ierr = MatSeqAIJKokkosBcast(mm->B_local,MAT_INITIAL_MATRIX,N,l2g1,a->Mvctx,mm->sf,
919                               mm->abuf,mm->rows,mm->rowoffset,mm->B_other);CHKERRQ(ierr);
920 
921   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
922   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other,l2g2);CHKERRQ(ierr);
923   ierr = MatProductCreate(Ao,mm->B_other,NULL,&C2);CHKERRQ(ierr);
924   ierr = MatProductSetType(C2,MATPRODUCT_AB);CHKERRQ(ierr);
925   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
926   C2->product->api_user = product->api_user;
927   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
928   PetscCheckFalse(!C2->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
929   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
930   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2,N,l2g2,mm->C2_global);
931 
932   /* C = C1 + C2.  We actually use their global col ids versions in adding */
933   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
934   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
935   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
936   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
937 
938   mm->C1 = C1;
939   mm->C2 = C2;
940   ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
941   ierr = ISDestroy(&glob);CHKERRQ(ierr);
942   PetscFunctionReturn(0);
943 }
944 
945 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
946 
947   Input Parameters:
948 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
949 .  A        - an MPIAIJKOKKOS matrix
950 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
951 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
952 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
953 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
954 -  mm       - a struct used to stash intermediate data in AtB
955 
956   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
957 */
958 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product,Mat A,Mat B,PetscBool localB,PetscInt N,const ConstMatColIdxKokkosView& l2g,MatMatStruct_AtB *mm)
959 {
960   PetscErrorCode         ierr;
961   Mat_MPIAIJ             *a = static_cast<Mat_MPIAIJ*>(A->data);
962   Mat                    Ad = a->A,Ao = a->B; /* diag and offdiag of A */
963   Mat                    C1,C2; /* intermediate matrices */
964 
965   PetscFunctionBegin;
966   /* C1 = Ad^t * B */
967   ierr = MatProductCreate(Ad,B,NULL,&C1);CHKERRQ(ierr);
968   ierr = MatProductSetType(C1,MATPRODUCT_AtB);CHKERRQ(ierr);
969   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
970   C1->product->api_user = product->api_user;
971   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
972   PetscCheckFalse(!C1->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
973   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
974 
975   if (localB) {ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g,mm->C1_global);}
976   else mm->C1_global = static_cast<Mat_SeqAIJKokkos*>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
977 
978   /* C2 = Ao^t * B */
979   ierr = MatProductCreate(Ao,B,NULL,&C2);CHKERRQ(ierr);
980   ierr = MatProductSetType(C2,MATPRODUCT_AtB);CHKERRQ(ierr);
981   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
982   C2->product->api_user = product->api_user;
983   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
984   PetscCheckFalse(!C2->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
985   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
986 
987   ierr = MatSeqAIJKokkosReduce(C2,MAT_INITIAL_MATRIX,localB,N,l2g,a->Mvctx,mm->sf,mm->abuf,
988                                mm->srcrowoffset,mm->dstrowoffset,mm->C2_global);CHKERRQ(ierr);
989 
990   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
991   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
992   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
993   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
994   mm->C1 = C1;
995   mm->C2 = C2;
996   PetscFunctionReturn(0);
997 }
998 
999 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1000 {
1001   PetscErrorCode                ierr;
1002   Mat_Product                   *product = C->product;
1003   MatProductType                ptype;
1004   MatProductData_MPIAIJKokkos   *mmdata;
1005   MatMatStruct                  *mm = NULL;
1006   MatMatStruct_AB               *ab;
1007   MatMatStruct_AtB              *atb;
1008   Mat                           A,B,Ad,Ao,Bd,Bo;
1009   const MatScalarType           one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1010 
1011   PetscFunctionBegin;
1012   MatCheckProduct(C,1);
1013   mmdata = static_cast<MatProductData_MPIAIJKokkos*>(product->data);
1014   ptype  = product->type;
1015   A      = product->A;
1016   B      = product->B;
1017   Ad     = static_cast<Mat_MPIAIJ*>(A->data)->A;
1018   Ao     = static_cast<Mat_MPIAIJ*>(A->data)->B;
1019   Bd     = static_cast<Mat_MPIAIJ*>(B->data)->A;
1020   Bo     = static_cast<Mat_MPIAIJ*>(B->data)->B;
1021 
1022   if (mmdata->reusesym) { /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1023     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1024     ab  = mmdata->mmAB;
1025     atb = mmdata->mmAtB;
1026     if (ab) {
1027       static_cast<MatProductData_SeqAIJKokkos*>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1028       static_cast<MatProductData_SeqAIJKokkos*>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1029     }
1030     if (atb) {
1031       static_cast<MatProductData_SeqAIJKokkos*>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1032       static_cast<MatProductData_SeqAIJKokkos*>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1033     }
1034     PetscFunctionReturn(0);
1035   }
1036 
1037   if (ptype == MATPRODUCT_AB) {
1038     ab   = mmdata->mmAB;
1039     /* C1 = Ad * B_local */
1040     PetscCheckFalse(!ab->C1->ops->productnumeric || !ab->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AB");
1041     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1042     PetscCheckFalse(ab->C1->product->B != ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1043     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1044     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1045     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1046                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1047     /* C2 = Ao * B_other */
1048     PetscCheckFalse(ab->C2->product->B != ab->B_other,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1049     if (ab->C1->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1050     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr);
1051     /* C = C1_global + C2_global */
1052     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1053     mm = static_cast<MatMatStruct*>(ab);
1054   } else if (ptype == MATPRODUCT_AtB) {
1055     atb  = mmdata->mmAtB;
1056     PetscCheckFalse(!atb->C1->ops->productnumeric || !atb->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AtB");
1057     /* C1 = Ad^t * B_local */
1058     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&atb->B_local);CHKERRQ(ierr);
1059     PetscCheckFalse(atb->C1->product->B != atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1060     if (atb->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1061     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1062 
1063     /* C2 = Ao^t * B_local */
1064     PetscCheckFalse(atb->C2->product->B != atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1065     if (atb->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1066     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1067     /* Form C2_global */
1068     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_TRUE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1069                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1070     /* C = C1_global + C2_global */
1071     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1072     mm = static_cast<MatMatStruct*>(atb);
1073   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1074     ab   = mmdata->mmAB;
1075     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1076 
1077     /* ab->C1 = Ad * B_local */
1078     PetscCheckFalse(ab->C1->product->B != ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1079     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1080     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1081     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1082                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1083     /* ab->C2 = Ao * B_other */
1084     if (ab->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1085     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr); /* C2 = Ao * B_other */
1086     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1087 
1088     /* atb->C1 = Bd^t * ab->C_petsc */
1089     atb  = mmdata->mmAtB;
1090     PetscCheckFalse(atb->C1->product->B != ab->C_petsc,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1091     if (atb->C1->product->A != Bd) {ierr = MatProductReplaceMats(Bd,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1092     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1093     /* atb->C2 = Bo^t * ab->C_petsc */
1094     if (atb->C2->product->A != Bo) {ierr = MatProductReplaceMats(Bo,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1095     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1096     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_FALSE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1097                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1098     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1099     mm = static_cast<MatMatStruct*>(atb);
1100   }
1101   /* Split C_global to form C */
1102   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_REUSE_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1103   PetscFunctionReturn(0);
1104 }
1105 
1106 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1107 {
1108   PetscErrorCode              ierr;
1109   Mat                         A,B;
1110   Mat_Product                 *product = C->product;
1111   MatProductType              ptype;
1112   MatProductData_MPIAIJKokkos *mmdata;
1113   MatMatStruct                *mm = NULL;
1114   IS                          glob = NULL;
1115   const PetscInt              *garray;
1116   PetscInt                    m,n,M,N,sz;
1117   ConstMatColIdxKokkosView    l2g; /* map local col ids to global ones */
1118 
1119   PetscFunctionBegin;
1120   MatCheckProduct(C,1);
1121   PetscCheckFalse(product->data,PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Product data not empty");
1122   ptype = product->type;
1123   A     = product->A;
1124   B     = product->B;
1125 
1126   switch (ptype) {
1127     case MATPRODUCT_AB:   m = A->rmap->n; n = B->cmap->n; M = A->rmap->N; N = B->cmap->N; break;
1128     case MATPRODUCT_AtB:  m = A->cmap->n; n = B->cmap->n; M = A->cmap->N; N = B->cmap->N; break;
1129     case MATPRODUCT_PtAP: m = B->cmap->n; n = B->cmap->n; M = B->cmap->N; N = B->cmap->N; break; /* BtAB */
1130     default: SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[ptype]);
1131   }
1132 
1133   ierr = MatSetSizes(C,m,n,M,N);CHKERRQ(ierr);
1134   ierr = PetscLayoutSetUp(C->rmap);CHKERRQ(ierr);
1135   ierr = PetscLayoutSetUp(C->cmap);CHKERRQ(ierr);
1136   ierr = MatSetType(C,((PetscObject)A)->type_name);CHKERRQ(ierr);
1137 
1138   mmdata           = new MatProductData_MPIAIJKokkos();
1139   mmdata->reusesym = product->api_user;
1140 
1141   if (ptype == MATPRODUCT_AB) {
1142     mmdata->mmAB = new MatMatStruct_AB();
1143     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,mmdata->mmAB);CHKERRQ(ierr);
1144     mm   = static_cast<MatMatStruct*>(mmdata->mmAB);
1145   } else if (ptype == MATPRODUCT_AtB) {
1146     mmdata->mmAtB = new MatMatStruct_AtB();
1147     auto atb      = mmdata->mmAtB;
1148     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&atb->B_local);CHKERRQ(ierr);
1149     ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
1150     ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
1151     l2g  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatColIdxKokkosViewHost(garray,sz));
1152     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,A,atb->B_local,PETSC_TRUE,N,l2g,atb);CHKERRQ(ierr);
1153     ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
1154     ierr = ISDestroy(&glob);CHKERRQ(ierr);
1155     mm   = static_cast<MatMatStruct*>(atb);
1156   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1157     mmdata->mmAB  = new MatMatStruct_AB(); /* tmp=A*B */
1158     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1159     auto ab       = mmdata->mmAB;
1160     auto atb      = mmdata->mmAtB;
1161     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,ab);CHKERRQ(ierr);
1162     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1163     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,tmp,&ab->C_petsc);CHKERRQ(ierr);
1164     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,B,ab->C_petsc,PETSC_FALSE,N,l2g/*not used*/,atb);CHKERRQ(ierr);
1165     mm   = static_cast<MatMatStruct*>(atb);
1166   }
1167   /* Split the C_global into petsc A, B format */
1168   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_INITIAL_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1169   C->product->data        = mmdata;
1170   C->product->destroy     = MatProductDataDestroy_MPIAIJKokkos;
1171   C->ops->productnumeric  = MatProductNumeric_MPIAIJKokkos;
1172   PetscFunctionReturn(0);
1173 }
1174 
1175 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1176 {
1177   PetscErrorCode ierr;
1178   Mat_Product    *product = mat->product;
1179   PetscBool      match = PETSC_FALSE;
1180   PetscBool      usecpu = PETSC_FALSE;
1181 
1182   PetscFunctionBegin;
1183   MatCheckProduct(mat,1);
1184   if (!product->A->boundtocpu && !product->B->boundtocpu) {
1185     ierr = PetscObjectTypeCompare((PetscObject)product->B,((PetscObject)product->A)->type_name,&match);CHKERRQ(ierr);
1186   }
1187   if (match) { /* we can always fallback to the CPU if requested */
1188     switch (product->type) {
1189     case MATPRODUCT_AB:
1190       if (product->api_user) {
1191         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatMatMult","Mat");CHKERRQ(ierr);
1192         ierr = PetscOptionsBool("-matmatmult_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1193         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1194       } else {
1195         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AB","Mat");CHKERRQ(ierr);
1196         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1197         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1198       }
1199       break;
1200     case MATPRODUCT_AtB:
1201       if (product->api_user) {
1202         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatTransposeMatMult","Mat");CHKERRQ(ierr);
1203         ierr = PetscOptionsBool("-mattransposematmult_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1204         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1205       } else {
1206         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AtB","Mat");CHKERRQ(ierr);
1207         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1208         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1209       }
1210       break;
1211     case MATPRODUCT_PtAP:
1212       if (product->api_user) {
1213         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatPtAP","Mat");CHKERRQ(ierr);
1214         ierr = PetscOptionsBool("-matptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1215         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1216       } else {
1217         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_PtAP","Mat");CHKERRQ(ierr);
1218         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1219         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1220       }
1221       break;
1222     default:
1223       break;
1224     }
1225     match = (PetscBool)!usecpu;
1226   }
1227   if (match) {
1228     switch (product->type) {
1229     case MATPRODUCT_AB:
1230     case MATPRODUCT_AtB:
1231     case MATPRODUCT_PtAP:
1232       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1233       break;
1234     default:
1235       break;
1236     }
1237   }
1238   /* fallback to MPIAIJ ops */
1239   if (!mat->ops->productsymbolic) {
1240     ierr = MatProductSetFromOptions_MPIAIJ(mat);CHKERRQ(ierr);
1241   }
1242   PetscFunctionReturn(0);
1243 }
1244 
1245 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, const PetscInt coo_i[], const PetscInt coo_j[])
1246 {
1247   Mat_MPIAIJ                *mpiaij = (Mat_MPIAIJ*)mat->data;
1248   Mat_MPIAIJKokkos          *mpikok;
1249   PetscErrorCode            ierr;
1250 
1251   PetscFunctionBegin;
1252   ierr = MatSetPreallocationCOO_MPIAIJ(mat,coo_n,coo_i,coo_j);CHKERRQ(ierr);
1253   mat->preallocated = PETSC_TRUE;
1254   ierr = MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
1255   ierr = MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
1256   ierr = MatZeroEntries(mat);CHKERRQ(ierr);
1257   mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
1258   delete mpikok;
1259   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1260   PetscFunctionReturn(0);
1261 }
1262 
/*
  Set the values of mat from the COO value array v[], doing the work on the device.

  v[] holds mpiaij->coo_n scalars; it may reside in host memory (it is then copied to the
  device) or in device memory (it is then used in place). Presumably v[] is ordered like the
  coo_i/coo_j arrays given to MatSetPreallocationCOO() -- confirm.

  With INSERT_VALUES, the diagonal block A and off-diagonal block B are zeroed first. With any
  other mode, values are added to the existing entries. In both cases, repeated COO entries
  are summed.

  Entries that belong to other ranks are gathered (through Cperm1) into sendbuf_d. They are
  reduced onto the owners' recvbuf_d through mpiaij->coo_sf, and the local entries are
  accumulated while that communication is in flight. The jmap/imap/perm views are read from
  mpikok; presumably they were built during COO preallocation -- confirm.
*/
static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat,const PetscScalar v[],InsertMode imode)
{
  PetscErrorCode                 ierr;
  Mat_MPIAIJ                     *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
  Mat_MPIAIJKokkos               *mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
  Mat                            A = mpiaij->A,B = mpiaij->B;
  PetscCount                     Annz1 = mpiaij->Annz1,Annz2 = mpiaij->Annz2,Bnnz1 = mpiaij->Bnnz1,Bnnz2 = mpiaij->Bnnz2;
  MatScalarKokkosView            Aa,Ba;
  MatScalarKokkosView            v1;
  MatScalarKokkosView&           vsend = mpikok->sendbuf_d;
  const MatScalarKokkosView&     v2 = mpikok->recvbuf_d;
  const PetscCountKokkosView&    Ajmap1 = mpikok->Ajmap1_d,Ajmap2 = mpikok->Ajmap2_d,Aimap1 = mpikok->Aimap1_d,Aimap2 = mpikok->Aimap2_d;
  const PetscCountKokkosView&    Bjmap1 = mpikok->Bjmap1_d,Bjmap2 = mpikok->Bjmap2_d,Bimap1 = mpikok->Bimap1_d,Bimap2 = mpikok->Bimap2_d;
  const PetscCountKokkosView&    Aperm1 = mpikok->Aperm1_d,Aperm2 = mpikok->Aperm2_d,Bperm1 = mpikok->Bperm1_d,Bperm2 = mpikok->Bperm2_d;
  const PetscCountKokkosView&    Cperm1 = mpikok->Cperm1_d;
  PetscMemType                   memtype;

  PetscFunctionBegin;
  ierr = PetscGetMemType(v,&memtype);CHKERRQ(ierr); /* Return PETSC_MEMTYPE_HOST when v is NULL */
  if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */
    v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),MatScalarKokkosViewHost((PetscScalar*)v,mpiaij->coo_n));
  } else {
    v1 = MatScalarKokkosView((PetscScalar*)v,mpiaij->coo_n); /* Directly use v[]'s memory */
  }

  if (imode == INSERT_VALUES) {
    ierr = MatSeqAIJGetKokkosViewWrite(A,&Aa);CHKERRQ(ierr); /* write matrix values */
    ierr = MatSeqAIJGetKokkosViewWrite(B,&Ba);CHKERRQ(ierr);
    Kokkos::deep_copy(Aa,0.0); /* Zero matrix values since INSERT_VALUES still requires summing replicated values in v[] */
    Kokkos::deep_copy(Ba,0.0);
  } else {
    ierr = MatSeqAIJGetKokkosView(A,&Aa);CHKERRQ(ierr); /* read & write matrix values */
    ierr = MatSeqAIJGetKokkosView(B,&Ba);CHKERRQ(ierr);
  }

  /* Pack entries to be sent to remote */
  /* vsend(i) takes the COO value at position Cperm1(i) */
  Kokkos::parallel_for(vsend.extent(0),KOKKOS_LAMBDA(const PetscCount i) {vsend(i) = v1(Cperm1(i));});

  /* Send remote entries to their owner and overlap the communication with local computation */
  ierr = PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf,MPIU_SCALAR,PETSC_MEMTYPE_KOKKOS,vsend.data(),PETSC_MEMTYPE_KOKKOS,v2.data(),MPI_REPLACE);CHKERRQ(ierr);
  /* Add local entries to A and B */
  /* For the i-th local target: sum v1(Aperm1(k)) for k in [Ajmap1(i),Ajmap1(i+1)) into Aa(Aimap1(i)); likewise for B */
  Kokkos::parallel_for(Annz1,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Ajmap1(i); k<Ajmap1(i+1); k++) Aa(Aimap1(i)) += v1(Aperm1(k));});
  Kokkos::parallel_for(Bnnz1,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Bjmap1(i); k<Bjmap1(i+1); k++) Ba(Bimap1(i)) += v1(Bperm1(k));});
  /* Complete the reduction; v2 (recvbuf_d) then holds the entries received from other ranks */
  ierr = PetscSFReduceEnd(mpiaij->coo_sf,MPIU_SCALAR,vsend.data(),v2.data(),MPI_REPLACE);CHKERRQ(ierr);

  /* Add received remote entries to A and B */
  /* Same accumulation pattern as above, but reading from v2 through the *jmap2/*imap2/*perm2 maps */
  Kokkos::parallel_for(Annz2,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Ajmap2(i); k<Ajmap2(i+1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));});
  Kokkos::parallel_for(Bnnz2,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Bjmap2(i); k<Bjmap2(i+1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));});

  /* Restore the views with the variant matching the Get above */
  if (imode == INSERT_VALUES) {
    ierr = MatSeqAIJRestoreKokkosViewWrite(A,&Aa);CHKERRQ(ierr); /* Increase A & B's state etc. */
    ierr = MatSeqAIJRestoreKokkosViewWrite(B,&Ba);CHKERRQ(ierr);
  } else {
    ierr = MatSeqAIJRestoreKokkosView(A,&Aa);CHKERRQ(ierr);
    ierr = MatSeqAIJRestoreKokkosView(B,&Ba);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}
1321 
1322 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1323 {
1324   PetscErrorCode     ierr;
1325   Mat_MPIAIJ         *mpiaij = (Mat_MPIAIJ*)A->data;
1326 
1327   PetscFunctionBegin;
1328   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",NULL);CHKERRQ(ierr);
1329   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",NULL);CHKERRQ(ierr);
1330   ierr = PetscObjectComposeFunction((PetscObject)A,"MatSetPreallocationCOO_C",   NULL);CHKERRQ(ierr);
1331   ierr = PetscObjectComposeFunction((PetscObject)A,"MatSetValuesCOO_C",          NULL);CHKERRQ(ierr);
1332   delete (Mat_MPIAIJKokkos*)mpiaij->spptr;
1333   ierr = MatDestroy_MPIAIJ(A);CHKERRQ(ierr);
1334   PetscFunctionReturn(0);
1335 }
1336 
1337 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat* newmat)
1338 {
1339   PetscErrorCode     ierr;
1340   Mat                B;
1341   Mat_MPIAIJ         *a;
1342 
1343   PetscFunctionBegin;
1344   if (reuse == MAT_INITIAL_MATRIX) {
1345     ierr = MatDuplicate(A,MAT_COPY_VALUES,newmat);CHKERRQ(ierr);
1346   } else if (reuse == MAT_REUSE_MATRIX) {
1347     ierr = MatCopy(A,*newmat,SAME_NONZERO_PATTERN);CHKERRQ(ierr);
1348   }
1349   B = *newmat;
1350 
1351   B->boundtocpu = PETSC_FALSE;
1352   ierr = PetscFree(B->defaultvectype);CHKERRQ(ierr);
1353   ierr = PetscStrallocpy(VECKOKKOS,&B->defaultvectype);CHKERRQ(ierr);
1354   ierr = PetscObjectChangeTypeName((PetscObject)B,MATMPIAIJKOKKOS);CHKERRQ(ierr);
1355 
1356   a = static_cast<Mat_MPIAIJ*>(A->data);
1357   if (a->A) {ierr = MatSetType(a->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1358   if (a->B) {ierr = MatSetType(a->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1359   if (a->lvec) {ierr = VecSetType(a->lvec,VECSEQKOKKOS);CHKERRQ(ierr);}
1360 
1361   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1362   B->ops->mult                  = MatMult_MPIAIJKokkos;
1363   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1364   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1365   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1366   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1367 
1368   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJSetPreallocation_C",MatMPIAIJSetPreallocation_MPIAIJKokkos);CHKERRQ(ierr);
1369   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJGetLocalMatMerge_C",MatMPIAIJGetLocalMatMerge_MPIAIJKokkos);CHKERRQ(ierr);
1370   ierr = PetscObjectComposeFunction((PetscObject)B,"MatSetPreallocationCOO_C",   MatSetPreallocationCOO_MPIAIJKokkos);CHKERRQ(ierr);
1371   ierr = PetscObjectComposeFunction((PetscObject)B,"MatSetValuesCOO_C",          MatSetValuesCOO_MPIAIJKokkos);CHKERRQ(ierr);
1372   PetscFunctionReturn(0);
1373 }
1374 /*MC
   MATMPIAIJKOKKOS - MATAIJKOKKOS = "(mpi)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos

   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1378 
1379    Options Database Keys:
1380 .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
1381 
1382   Level: beginner
1383 
1384 .seealso: MatCreateAIJKokkos(), MATSEQAIJKOKKOS
1385 M*/
1386 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1387 {
1388   PetscErrorCode ierr;
1389 
1390   PetscFunctionBegin;
1391   ierr = PetscKokkosInitializeCheck();CHKERRQ(ierr);
1392   ierr = MatCreate_MPIAIJ(A);CHKERRQ(ierr);
1393   ierr = MatConvert_MPIAIJ_MPIAIJKokkos(A,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&A);CHKERRQ(ierr);
1394   PetscFunctionReturn(0);
1395 }
1396 
1397 /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.
1404 
1405    Collective
1406 
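   Input Parameters:
+  comm - MPI communicator
.  m - number of local rows (or PETSC_DECIDE)
.  n - number of local columns (or PETSC_DECIDE)
.  M - number of global rows (or PETSC_DETERMINE)
.  N - number of global columns (or PETSC_DETERMINE)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion of the
           local submatrix (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows);
          ignored when comm has a single process
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion of the
           local submatrix (possibly different for each row) or NULL; ignored when comm has a single process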
1414 
1415    Output Parameter:
1416 .  A - the matrix
1417 
1418    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
1419    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1420    [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]
1421 
   Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored
1424 
1425    The AIJ format (also called the Yale sparse matrix format or
1426    compressed row storage), is fully compatible with standard Fortran 77
1427    storage.  That is, the stored row and column indices can begin at
1428    either one (as in Fortran) or zero.  See the users' manual for details.
1429 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz=o_nz=PETSC_DEFAULT and d_nnz=o_nnz=NULL for PETSc to control dynamic memory
   allocation.  For large problems you MUST preallocate memory or you
   will get TERRIBLE performance, see the users' manual chapter on matrices.
1434 
1435    By default, this format uses inodes (identical nodes) when possible, to
1436    improve numerical efficiency of matrix-vector products and solves. We
1437    search for consecutive rows with the same nonzero structure, thereby
1438    reusing matrix information to achieve increased efficiency.
1439 
1440    Level: intermediate
1441 
.seealso: MatCreate(), MatCreateAIJ(), MatSetValues(), MatSeqAIJSetColumnIndices(), MatCreateSeqAIJWithArrays(), MATMPIAIJKOKKOS, MATAIJKOKKOS
1443 @*/
1444 PetscErrorCode  MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat *A)
1445 {
1446   PetscErrorCode ierr;
1447   PetscMPIInt    size;
1448 
1449   PetscFunctionBegin;
1450   ierr = MatCreate(comm,A);CHKERRQ(ierr);
1451   ierr = MatSetSizes(*A,m,n,M,N);CHKERRQ(ierr);
1452   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
1453   if (size > 1) {
1454     ierr = MatSetType(*A,MATMPIAIJKOKKOS);CHKERRQ(ierr);
1455     ierr = MatMPIAIJSetPreallocation(*A,d_nz,d_nnz,o_nz,o_nnz);CHKERRQ(ierr);
1456   } else {
1457     ierr = MatSetType(*A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
1458     ierr = MatSeqAIJSetPreallocation(*A,d_nz,d_nnz);CHKERRQ(ierr);
1459   }
1460   PetscFunctionReturn(0);
1461 }
1462 
1463 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1464 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1465 {
1466   PetscMPIInt                size,rank;
1467   MPI_Comm                   comm;
1468   PetscErrorCode             ierr;
1469   PetscSplitCSRDataStructure d_mat=NULL;
1470 
1471   PetscFunctionBegin;
1472   ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRQ(ierr);
1473   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
1474   ierr = MPI_Comm_rank(comm,&rank);CHKERRMPI(ierr);
1475   if (size == 1) {
1476     ierr   = MatSeqAIJKokkosGetDeviceMat(A,&d_mat);CHKERRQ(ierr);
1477     ierr   = MatSeqAIJKokkosModifyDevice(A);CHKERRQ(ierr); /* Since we are going to modify matrix values on device */
1478   } else {
1479     Mat_MPIAIJ  *aij = (Mat_MPIAIJ*)A->data;
1480     ierr   = MatSeqAIJKokkosGetDeviceMat(aij->A,&d_mat);CHKERRQ(ierr);
1481     ierr   = MatSeqAIJKokkosModifyDevice(aij->A);CHKERRQ(ierr);
1482     ierr   = MatSeqAIJKokkosModifyDevice(aij->B);CHKERRQ(ierr);
1483     PetscCheck(A->nooffprocentries || aij->donotstash,PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1484   }
1485   // act like MatSetValues because not called on host
1486   if (A->assembled) {
1487     if (A->was_assembled) {
1488       ierr = PetscInfo(A,"Assemble more than once already\n");CHKERRQ(ierr);
1489     }
1490     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1491   } else {
1492     ierr = PetscInfo(A,"Warning !assemble ??? assembled=%" PetscInt_FMT "\n",A->assembled);CHKERRQ(ierr);
1493   }
1494   if (!d_mat) {
1495     struct _n_SplitCSRMat h_mat; /* host container */
1496     Mat_SeqAIJKokkos      *aijkokA;
1497     Mat_SeqAIJ            *jaca;
1498     PetscInt              n = A->rmap->n, nnz;
1499     Mat                   Amat;
1500     PetscInt              *colmap;
1501 
1502     /* create and copy h_mat */
1503     h_mat.M = A->cmap->N; // use for debug build
1504     ierr = PetscInfo(A,"Create device matrix in Kokkos\n");CHKERRQ(ierr);
1505     if (size == 1) {
1506       Amat = A;
1507       jaca = (Mat_SeqAIJ*)A->data;
1508       h_mat.rstart = 0; h_mat.rend = A->rmap->n;
1509       h_mat.cstart = 0; h_mat.cend = A->cmap->n;
1510       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1511       h_mat.offdiag.a = NULL;
1512       aijkokA = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1513     } else {
1514       Mat_MPIAIJ       *aij = (Mat_MPIAIJ*)A->data;
1515       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ*)aij->B->data;
1516       PetscInt         ii;
1517       Mat_SeqAIJKokkos *aijkokB;
1518 
1519       Amat = aij->A;
1520       aijkokA = static_cast<Mat_SeqAIJKokkos*>(aij->A->spptr);
1521       aijkokB = static_cast<Mat_SeqAIJKokkos*>(aij->B->spptr);
1522       jaca = (Mat_SeqAIJ*)aij->A->data;
1523       PetscCheckFalse(aij->B->cmap->n && !aij->garray,comm,PETSC_ERR_PLIB,"MPIAIJ Matrix was assembled but is missing garray");
1524       PetscCheckFalse(aij->B->rmap->n != aij->A->rmap->n,comm,PETSC_ERR_SUP,"Only support aij->B->rmap->n == aij->A->rmap->n");
1525       aij->donotstash = PETSC_TRUE;
1526       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1527       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1528       ierr = PetscCalloc1(A->cmap->N,&colmap);CHKERRQ(ierr);
1529       ierr = PetscLogObjectMemory((PetscObject)A,(A->cmap->N)*sizeof(PetscInt));CHKERRQ(ierr);
1530       for (ii=0; ii<aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii+1;
1531       // allocate B copy data
1532       h_mat.rstart = A->rmap->rstart; h_mat.rend = A->rmap->rend;
1533       h_mat.cstart = A->cmap->rstart; h_mat.cend = A->cmap->rend;
1534       nnz = jacb->i[n];
1535       if (jacb->compressedrow.use) {
1536         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_i_k (jacb->i,n+1);
1537         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_i_k));
1538         Kokkos::deep_copy (aijkokB->i_uncompressed_d, h_i_k);
1539         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1540       } else {
1541          h_mat.offdiag.i = aijkokB->i_device_data();
1542       }
1543       h_mat.offdiag.j = aijkokB->j_device_data();
1544       h_mat.offdiag.a = aijkokB->a_device_data();
1545       {
1546         Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_colmap_k (colmap,A->cmap->N);
1547         aijkokB->colmap_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_colmap_k));
1548         Kokkos::deep_copy (aijkokB->colmap_d, h_colmap_k);
1549         h_mat.colmap = aijkokB->colmap_d.data();
1550         ierr = PetscFree(colmap);CHKERRQ(ierr);
1551       }
1552       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1553       h_mat.offdiag.n = n;
1554     }
1555     // allocate A copy data
1556     nnz = jaca->i[n];
1557     h_mat.diag.n = n;
1558     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1559     ierr = MPI_Comm_rank(comm,&h_mat.rank);CHKERRMPI(ierr);
1560     PetscCheckFalse(jaca->compressedrow.use,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A does not suppport compressed row (todo)");
1561     else {
1562       h_mat.diag.i = aijkokA->i_device_data();
1563     }
1564     h_mat.diag.j = aijkokA->j_device_data();
1565     h_mat.diag.a = aijkokA->a_device_data();
1566     // copy pointers and metdata to device
1567     ierr = MatSeqAIJKokkosSetDeviceMat(Amat,&h_mat);CHKERRQ(ierr);
1568     ierr = MatSeqAIJKokkosGetDeviceMat(Amat,&d_mat);CHKERRQ(ierr);
1569     ierr = PetscInfo(A,"Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n",h_mat.diag.n, nnz);CHKERRQ(ierr);
1570   }
1571   *B = d_mat; // return it, set it in Mat, and set it up
1572   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1573   PetscFunctionReturn(0);
1574 }
1575 
1576 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1577 {
1578   Mat_SeqAIJKokkos  *aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1579 
1580   PetscFunctionBegin;
1581   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1582   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1583   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1584   else *mask = "PETSC_OFFLOAD_BOTH";
1585   PetscFunctionReturn(0);
1586 }
1587 
1588 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1589 {
1590   PetscErrorCode    ierr;
1591   PetscMPIInt       size;
1592   Mat               Ad,Ao;
1593   const char        *amask,*bmask;
1594 
1595   PetscFunctionBegin;
1596   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRMPI(ierr);
1597 
1598   if (size == 1) {
1599     ierr = MatSeqAIJKokkosGetOffloadMask(A,&amask);CHKERRQ(ierr);
1600     ierr = PetscPrintf(PETSC_COMM_SELF,"%s\n",amask);CHKERRQ(ierr);
1601   } else {
1602     Ad  = ((Mat_MPIAIJ*)A->data)->A;
1603     Ao  = ((Mat_MPIAIJ*)A->data)->B;
1604     ierr = MatSeqAIJKokkosGetOffloadMask(Ad,&amask);CHKERRQ(ierr);
1605     ierr = MatSeqAIJKokkosGetOffloadMask(Ao,&bmask);CHKERRQ(ierr);
1606     ierr = PetscPrintf(PETSC_COMM_SELF,"Diag : Off-diag = %s : %s\n",amask,bmask);CHKERRQ(ierr);
1607   }
1608   PetscFunctionReturn(0);
1609 }
1610