xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 8fd105b637c659b5723a6c3ba83a32bc84aa12fb)
1 #include <petscvec_kokkos.hpp>
2 #include <petscsf.h>
3 #include <petsc/private/sfimpl.h>
4 #include <../src/mat/impls/aij/mpi/mpiaij.h>
5 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6 #include <KokkosSparse_spadd.hpp>
7 
8 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)
9 {
10   Mat_SeqAIJKokkos *aijkok;
11 
12   PetscFunctionBegin;
13   PetscCall(MatAssemblyEnd_MPIAIJ(A,mode));
14   aijkok = static_cast<Mat_SeqAIJKokkos*>(((Mat_MPIAIJ*)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
15   if (aijkok && aijkok->device_mat_d.data()) {
16     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
17   }
18 
19   PetscFunctionReturn(0);
20 }
21 
22 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])
23 {
24   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
25 
26   PetscFunctionBegin;
27   PetscCall(PetscLayoutSetUp(mat->rmap));
28   PetscCall(PetscLayoutSetUp(mat->cmap));
29 #if defined(PETSC_USE_DEBUG)
30   if (d_nnz) {
31     PetscInt i;
32     for (i=0; i<mat->rmap->n; i++) {
33       PetscCheck(d_nnz[i] >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,d_nnz[i]);
34     }
35   }
36   if (o_nnz) {
37     PetscInt i;
38     for (i=0; i<mat->rmap->n; i++) {
39       PetscCheck(o_nnz[i] >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,o_nnz[i]);
40     }
41   }
42 #endif
43 #if defined(PETSC_USE_CTABLE)
44   PetscCall(PetscTableDestroy(&mpiaij->colmap));
45 #else
46   PetscCall(PetscFree(mpiaij->colmap));
47 #endif
48   PetscCall(PetscFree(mpiaij->garray));
49   PetscCall(VecDestroy(&mpiaij->lvec));
50   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
51   /* Because the B will have been resized we simply destroy it and create a new one each time */
52   PetscCall(MatDestroy(&mpiaij->B));
53 
54   if (!mpiaij->A) {
55     PetscCall(MatCreate(PETSC_COMM_SELF,&mpiaij->A));
56     PetscCall(MatSetSizes(mpiaij->A,mat->rmap->n,mat->cmap->n,mat->rmap->n,mat->cmap->n));
57     PetscCall(PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->A));
58   }
59   if (!mpiaij->B) {
60     PetscMPIInt size;
61     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat),&size));
62     PetscCall(MatCreate(PETSC_COMM_SELF,&mpiaij->B));
63     PetscCall(MatSetSizes(mpiaij->B,mat->rmap->n,size > 1 ? mat->cmap->N : 0,mat->rmap->n,size > 1 ? mat->cmap->N : 0));
64     PetscCall(PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->B));
65   }
66   PetscCall(MatSetType(mpiaij->A,MATSEQAIJKOKKOS));
67   PetscCall(MatSetType(mpiaij->B,MATSEQAIJKOKKOS));
68   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A,d_nz,d_nnz));
69   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B,o_nz,o_nnz));
70   mat->preallocated = PETSC_TRUE;
71   PetscFunctionReturn(0);
72 }
73 
74 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
75 {
76   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
77   PetscInt       nt;
78 
79   PetscFunctionBegin;
80   PetscCall(VecGetLocalSize(xx,&nt));
81   PetscCheck(nt == mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
82   PetscCall(VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
83   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A,xx,yy));
84   PetscCall(VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
85   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,yy,yy));
86   PetscFunctionReturn(0);
87 }
88 
89 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)
90 {
91   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
92   PetscInt       nt;
93 
94   PetscFunctionBegin;
95   PetscCall(VecGetLocalSize(xx,&nt));
96   PetscCheck(nt == mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
97   PetscCall(VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
98   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A,xx,yy,zz));
99   PetscCall(VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD));
100   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,zz,zz));
101   PetscFunctionReturn(0);
102 }
103 
104 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
105 {
106   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
107   PetscInt       nt;
108 
109   PetscFunctionBegin;
110   PetscCall(VecGetLocalSize(xx,&nt));
111   PetscCheck(nt == mat->rmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->rmap->n,nt);
112   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B,xx,mpiaij->lvec));
113   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A,xx,yy));
114   PetscCall(VecScatterBegin(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE));
115   PetscCall(VecScatterEnd(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE));
116   PetscFunctionReturn(0);
117 }
118 
119 /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
120    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
121    C still uses local column ids. Their corresponding global column ids are returned in glob.
122 */
123 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS *glob,Mat *C)
124 {
125   Mat            Ad,Ao;
126   const PetscInt *cmap;
127 
128   PetscFunctionBegin;
129   PetscCall(MatMPIAIJGetSeqAIJ(mat,&Ad,&Ao,&cmap));
130   PetscCall(MatSeqAIJKokkosMergeMats(Ad,Ao,reuse,C));
131   if (glob) {
132     PetscInt cst, i, dn, on, *gidx;
133     PetscCall(MatGetLocalSize(Ad,NULL,&dn));
134     PetscCall(MatGetLocalSize(Ao,NULL,&on));
135     PetscCall(MatGetOwnershipRangeColumn(mat,&cst,NULL));
136     PetscCall(PetscMalloc1(dn+on,&gidx));
137     for (i=0; i<dn; i++) gidx[i]    = cst + i;
138     for (i=0; i<on; i++) gidx[i+dn] = cmap[i];
139     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad),dn+on,gidx,PETSC_OWN_POINTER,glob));
140   }
141   PetscFunctionReturn(0);
142 }
143 
144 /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
145 struct MatMatStruct {
146   MatRowMapKokkosView   Cdstart; /* Used to split sequential matrix into petsc's A, B format */
147   PetscSF               sf; /* SF to send/recv matrix entries */
148   MatScalarKokkosView   abuf; /* buf of mat values in send/recv */
149   Mat                   C1,C2,B_local;
150   KokkosCsrMatrix       C1_global,C2_global,C_global;
151   KernelHandle          kh;
152   MatMatStruct() {
153     C1 = C2 = B_local = NULL;
154     sf = NULL;
155   }
156 
157   ~MatMatStruct() {
158     MatDestroy(&C1);
159     MatDestroy(&C2);
160     MatDestroy(&B_local);
161     PetscSFDestroy(&sf);
162     kh.destroy_spadd_handle();
163   }
164 };
165 
166 struct MatMatStruct_AB : public MatMatStruct {
167   MatColIdxKokkosView   rows;
168   MatRowMapKokkosView   rowoffset;
169   Mat                   B_other,C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
170 
171   MatMatStruct_AB() : B_other(NULL),C_petsc(NULL){}
172   ~MatMatStruct_AB() {
173     MatDestroy(&B_other);
174     MatDestroy(&C_petsc);
175   }
176 };
177 
178 struct MatMatStruct_AtB : public MatMatStruct {
179   MatRowMapKokkosView   srcrowoffset,dstrowoffset;
180 };
181 
182 struct MatProductData_MPIAIJKokkos
183 {
184   MatMatStruct_AB   *mmAB;
185   MatMatStruct_AtB  *mmAtB;
186   PetscBool         reusesym;
187 
188   MatProductData_MPIAIJKokkos(): mmAB(NULL),mmAtB(NULL),reusesym(PETSC_FALSE){}
189   ~MatProductData_MPIAIJKokkos() {
190     delete mmAB;
191     delete mmAtB;
192   }
193 };
194 
195 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
196 {
197   PetscFunctionBegin;
198   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos*>(data));
199   PetscFunctionReturn(0);
200 }
201 
202 /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
203 
204    Input Parameters:
205 +  A       - the MATSEQAIJKOKKOS matrix
206 .  N       - new column size for the returned Kokkos matrix
207 -  l2g     - a map that maps old col ids to new col ids
208 
209    Output Parameters:
210 .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
211  */
212 static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A,PetscInt N,const ConstMatColIdxKokkosView& l2g,KokkosCsrMatrix& csrmat)
213 {
214   KokkosCsrMatrix&         orig = static_cast<Mat_SeqAIJKokkos*>(A->spptr)->csrmat;
215   MatColIdxKokkosView      jg("jg",orig.nnz()); /* New j array for csrmat */
216 
217   PetscFunctionBegin;
218   PetscCallCXX(Kokkos::parallel_for(orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) {jg(i) = l2g(orig.graph.entries(i));}));
219   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat",orig.numRows(),N,orig.nnz(),orig.values,orig.graph.row_map,jg));
220   PetscFunctionReturn(0);
221 }
222 
223 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
224    It is similar to MatCreateMPIAIJWithSplitArrays.
225 
226   Input Parameters:
227 +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
228 .  A     - the diag matrix using local col ids
229 -  B     - the offdiag matrix using global col ids
230 
231   Output Parameters:
232 .  mat   - the updated MATMPIAIJKOKKOS matrix
233 */
234 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat,Mat A,Mat B)
235 {
236   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
237   PetscInt            m,n,M,N,Am,An,Bm,Bn;
238   Mat_SeqAIJKokkos    *bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
239 
240   PetscFunctionBegin;
241   PetscCall(MatGetSize(mat,&M,&N));
242   PetscCall(MatGetLocalSize(mat,&m,&n));
243   PetscCall(MatGetLocalSize(A,&Am,&An));
244   PetscCall(MatGetLocalSize(B,&Bm,&Bn));
245 
246   PetscCheck(m == Am && m == Bm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of rows do not match");
247   PetscCheck(n == An,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of columns do not match");
248   PetscCheck(N == Bn,PETSC_COMM_SELF,PETSC_ERR_PLIB,"global number of columns do not match");
249   PetscCheck(!mpiaij->A && !mpiaij->B,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A, B of the MPIAIJ matrix are not empty");
250   mpiaij->A = A;
251   mpiaij->B = B;
252 
253   mat->preallocated     = PETSC_TRUE;
254   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
255 
256   PetscCall(MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE));
257   PetscCall(MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY));
258   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
259     also gets mpiaij->B compacted, with its col ids and size reduced
260   */
261   PetscCall(MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY));
262   PetscCall(MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_FALSE));
263   PetscCall(MatSetOption(mat,MAT_NEW_NONZERO_LOCATION_ERR,PETSC_TRUE));
264 
265   /* Update bkok with new local col ids (stored on host) and size */
266   bkok->j_dual.modify_host();
267   bkok->j_dual.sync_device();
268   bkok->SetColSize(mpiaij->B->cmap->n);
269   PetscFunctionReturn(0);
270 }
271 
272 /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
273 
274    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
275    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
276    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
277    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
278 
279    Collective on comm of ownerSF
280 
281    Input Parameters:
282 +   B       - the SEQAIJKOKKOS matrix, using local col ids
283 .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
284 .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
285 .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
286 .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
287 
288    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
289 +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
290 .   abuf      - buffer for sending matrix values
291 .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
292                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
293 .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
294 -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
295 */
296 static PetscErrorCode MatSeqAIJKokkosBcast(Mat B,MatReuse reuse,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
297                                            PetscSF& bcastSF,MatScalarKokkosView& abuf,MatColIdxKokkosView& rows,
298                                            MatRowMapKokkosView& rowoffset,Mat& C)
299 {
300   Mat_SeqAIJKokkos             *bkok,*ckok;
301 
302   PetscFunctionBegin;
303   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
304   bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
305 
306   if (reuse == MAT_REUSE_MATRIX) {
307     ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
308 
309     const auto& Ba = bkok->a_dual.view_device();
310     const auto& Bi = bkok->i_dual.view_device();
311     const auto& Ca = ckok->a_dual.view_device();
312 
313     /* Copy Ba to abuf */
314     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
315       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
316       PetscInt r    = rows(i);
317       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
318       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
319         abuf(base+k) = Ba(Bi(r)+k);
320       });
321     });
322 
323     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
324     PetscCall(PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE)); /* TODO: get memtype for abuf */
325     PetscCall(PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE));
326     ckok->a_dual.modify_device();
327   } else if (reuse == MAT_INITIAL_MATRIX) {
328     MPI_Comm       comm;
329     PetscMPIInt    tag;
330     PetscInt       k,Cm,Cn,Cnnz,*Ci_h,nroots,nleaves;
331 
332     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF,&comm));
333     PetscCall(PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL));
334     Cm   = nleaves; /* row size of C */
335     Cn   = N;  /* col size of C, which initially uses global ids, so we can safely set its col size as N */
336 
337     /* Get row lens (nz) of B's rows for later fast query */
338     PetscInt       *Browlens;
339     const PetscInt *tmp = bkok->i_host_data();
340     PetscCall(PetscMalloc1(nroots,&Browlens));
341     for (k=0; k<nroots; k++) Browlens[k] = tmp[k+1]-tmp[k];
342 
343     /* By ownerSF, each proc gets lens of rows of C */
344     MatRowMapKokkosDualView Ci("i",Cm+1); /* C's rowmap */
345     Ci_h    = Ci.view_host().data();
346     Ci_h[0] = 0;
347     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF,MPIU_INT,PETSC_MEMTYPE_HOST,Browlens,PETSC_MEMTYPE_HOST,&Ci_h[1],MPI_REPLACE));
348     PetscCall(PetscSFBcastEnd(ownerSF,MPIU_INT,Browlens,&Ci_h[1],MPI_REPLACE));
349     for (k=1; k<Cm+1; k++) Ci_h[k] += Ci_h[k-1]; /* Convert lens to CSR */
350     Cnnz    = Ci_h[Cm];
351     Ci.modify_host();
352     Ci.sync_device();
353 
354     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
355     MatColIdxKokkosDualView  Cj("j",Cnnz);
356     MatScalarKokkosDualView  Ca("a",Cnnz);
357 
358     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
359     const PetscMPIInt *iranks,*ranks;
360     const PetscInt    *ioffset,*irootloc,*roffset;
361     PetscInt          i,j,niranks,nranks,*sdisp,*rdisp,*rowptr;
362     MPI_Request       *reqs;
363 
364     PetscCall(PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&irootloc)); /* irootloc[] contains indices of rows I need to send to each receiver */
365     PetscCall(PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/)); /* recv info */
366 
367     /* figure out offsets at the send buffer, to build the SF
368       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
369       rowptr[] - stores offsets for data of each row in abuf
370 
371       rdisp[]  - to receive sdisp[]
372     */
373     PetscCall(PetscMalloc3(niranks+1,&sdisp,nranks,&rdisp,niranks+nranks,&reqs));
374     MatRowMapKokkosViewHost rowptr_h("rowptr_h",ioffset[niranks]+1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
375     rowptr = rowptr_h.data();
376 
377     sdisp[0] = 0;
378     rowptr[0]  = 0;
379     for (i=0; i<niranks; i++) { /* for each receiver */
380       PetscInt len, nz = 0;
381       for (j=ioffset[i]; j<ioffset[i+1]; j++) { /* for each row to this receiver */
382         len         = Browlens[irootloc[j]];
383         rowptr[j+1] = rowptr[j] + len;
384         nz         += len;
385       }
386       sdisp[i+1] = sdisp[i] + nz;
387     }
388     PetscCallMPI(PetscCommGetNewTag(comm,&tag));
389     for (i=0; i<nranks; i++)  PetscCallMPI(MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]));
390     for (i=0; i<niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]));
391     PetscCallMPI(MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE));
392 
393     PetscInt    nleaves2 = Cnnz; /* leaves are the nonzeros I will receive */
394     PetscInt    nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
395     PetscSFNode *iremote;
396     PetscCall(PetscMalloc1(nleaves2,&iremote));
397     for (i=0; i<nranks; i++) { /* for each sender */
398       k = 0;
399       for (j=Ci_h[roffset[i]]; j<Ci_h[roffset[i+1]]; j++) {
400         iremote[j].rank  = ranks[i];
401         iremote[j].index = rdisp[i] + k;
402         k++;
403       }
404     }
405     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
406     PetscCall(PetscSFCreate(comm,&bcastSF));
407     PetscCall(PetscSFSetGraph(bcastSF,nroots2,nleaves2,NULL/*ilocal*/,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER));
408 
409     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
410       from local to global. Then use bcastSF to fill Ca, Cj.
411     */
412     ConstMatColIdxKokkosViewHost rows_h(irootloc,ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
413     MatColIdxKokkosView          rows("rows",ioffset[niranks]);
414     Kokkos::deep_copy(rows,rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
415 
416     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
417 
418     MatColIdxKokkosView jbuf("jbuf",sdisp[niranks]); /* send buf for (global) col ids */
419     abuf = MatScalarKokkosView("abuf",sdisp[niranks]); /* send buf for mat values */
420 
421     const auto& Ba = bkok->a_dual.view_device();
422     const auto& Bi = bkok->i_dual.view_device();
423     const auto& Bj = bkok->j_dual.view_device();
424 
425     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
426     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
427       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
428       PetscInt r    = rows(i);
429       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
430       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
431         abuf(base+k) = Ba(Bi(r)+k);
432         jbuf(base+k) = l2g(Bj(Bi(r)+k));
433       });
434     });
435 
436     /* Send abuf & jbuf to fill Ca, Cj */
437     PetscCall(PetscSFBcastBegin(bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE));
438     PetscCall(PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE));
439     PetscCall(PetscSFBcastEnd  (bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE));
440     PetscCall(PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE));
441     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
442     Cj.sync_host();
443     Ca.modify_device();
444 
445     /* Construct C with Ca, Ci, Cj */
446     auto ckok = new Mat_SeqAIJKokkos(Cm,Cn,Cnnz,Ci,Cj,Ca);
447     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,ckok,&C));
448     PetscCall(PetscFree3(sdisp,rdisp,reqs));
449     PetscCall(PetscFree(Browlens));
450   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
451   PetscFunctionReturn(0);
452 }
453 
454 /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
455 
456   It is the reverse of MatSeqAIJKokkosBcast in some sense.
457 
458   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
459   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
460   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
461 
462   Input Parameters:
463 +  A        - the SEQAIJKOKKOS matrix to be reduced
464 .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
465 .  local    - true if A uses local col ids; false if A is already in global col ids.
466 .  N        - if local, N is A's global col size
467 .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
468 -  ownerSF  - the SF specifies ownership (root) of rows in A
469 
470   Output Parameters:
471 +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
472 .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
473 .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
474 .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
475                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
476 -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
477 
478    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
479  */
480 static PetscErrorCode MatSeqAIJKokkosReduce(Mat A,MatReuse reuse,PetscBool local,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
481                                             PetscSF& reduceSF,MatScalarKokkosView& abuf,
482                                             MatRowMapKokkosView& srcrowoffset,MatRowMapKokkosView& dstrowoffset,
483                                             KokkosCsrMatrix& C)
484 {
485   PetscInt               i,r,Am,An,Annz,Cnnz,nrows;
486   const PetscInt         *Ai;
487   Mat_SeqAIJKokkos       *akok;
488 
489   PetscFunctionBegin;
490   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
491   PetscCall(MatGetSize(A,&Am,&An));
492   Ai   = static_cast<Mat_SeqAIJ*>(A->data)->i;
493   akok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
494   Annz = Ai[Am];
495 
496   if (reuse == MAT_REUSE_MATRIX) {
497     /* Send Aa to abuf */
498     PetscCallMPI(PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
499     PetscCallMPI(PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
500 
501     /* Copy abuf to Ca */
502     const MatScalarKokkosView& Ca = C.values;
503     nrows = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
504     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
505       PetscInt i   = t.league_rank();
506       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
507       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
508       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {Ca(dst+k) = abuf(src+k);});
509     });
510   } else if (reuse == MAT_INITIAL_MATRIX) {
511     MPI_Comm               comm;
512     MPI_Request            *reqs;
513     PetscMPIInt            tag;
514     PetscInt               Cm;
515 
516     PetscCall(PetscObjectGetComm((PetscObject)ownerSF,&comm));
517     PetscCall(PetscCommGetNewTag(comm,&tag));
518 
519     PetscInt niranks,nranks,nroots,nleaves;
520     const PetscMPIInt *iranks,*ranks;
521     const PetscInt *ioffset,*rows,*roffset;  /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
522     PetscCall(PetscSFSetUp(ownerSF));
523     PetscCall(PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&rows)); /* recv info: iranks[] will send rows to me */
524     PetscCall(PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/)); /* send info */
525     PetscCall(PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL));
526     PetscCheck(nleaves == Am,PETSC_COMM_SELF,PETSC_ERR_PLIB,"ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")",nleaves,Am);
527     Cm    = nroots;
528     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
529 
530     /* Tell owners how long each row I will send */
531     PetscInt                *srowlens; /* send buf of row lens */
532     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h",nrows+1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
533     PetscInt                *rrowlens = rrowlens_h.data();
534 
535     PetscCall(PetscMalloc2(Am,&srowlens,niranks+nranks,&reqs));
536     for (i=0; i<Am; i++) srowlens[i] = Ai[i+1] - Ai[i];
537     rrowlens[0] = 0;
538     rrowlens++; /* shift the pointer to make the following expression more readable */
539     for (i=0; i<niranks; i++)PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]],ioffset[i+1]-ioffset[i],MPIU_INT,iranks[i],tag,comm,&reqs[i]));
540     for (i=0; i<nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]],roffset[i+1]-roffset[i],MPIU_INT,ranks[i],tag,comm,&reqs[niranks+i]));
541     PetscCallMPI(MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE));
542 
543     /* Owner builds Ci on host by histogramming rrowlens[] */
544     MatRowMapKokkosViewHost Ci_h("i",Cm+1);
545     Kokkos::deep_copy(Ci_h,0); /* Zero Ci */
546     MatRowMapType *Ci_ptr = Ci_h.data();
547 
548     for (i=0; i<nrows; i++) {
549       r = rows[i]; /* local row id of i-th received row */
550      #if defined(PETSC_USE_DEBUG)
551       PetscCheck(r >= 0 && r < Cm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")",r,Cm);
552      #endif
553       Ci_ptr[r+1] += rrowlens[i]; /* add to length of row r in C */
554     }
555     for (i=0; i<Cm; i++) Ci_ptr[i+1] += Ci_ptr[i]; /* to CSR format */
556     Cnnz = Ci_ptr[Cm];
557 
558     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
559     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h",nrows);
560     PetscInt                *dstrowoffset_hptr = dstrowoffset_h.data();
561     PetscInt                *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
562 
563     PetscCall(PetscCalloc1(Cm,&currowlens)); /* Init with zero, to be added to */
564     for (i=0; i<nrows; i++) { /* for each row I receive */
565       r                    = rows[i]; /* row id in C */
566       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
567       currowlens[r]       += rrowlens[i]; /* accumulate to length of row r in C */
568     }
569     PetscCall(PetscFree(currowlens));
570 
571     rrowlens--;
572     for (i=0; i<nrows; i++) rrowlens[i+1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
573     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),dstrowoffset_h);
574     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
575 
576     /* Build the reduceSF, which performs buffer to buffer send/recv */
577     PetscInt *sdisp,*rdisp; /* buffer to send offsets of roots, and buffer to recv them */
578     PetscCall(PetscMalloc2(niranks,&sdisp,nranks,&rdisp));
579     for (i=0; i<niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
580     for (i=0; i<nranks; i++)  PetscCallMPI(MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]));
581     for (i=0; i<niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]));
582     PetscCallMPI(MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE));
583 
584     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
585     PetscInt    nroots2 = Cnnz,nleaves2 = Annz;
586     PetscSFNode *iremote;
587     PetscCall(PetscMalloc1(nleaves2,&iremote)); /* no free, since memory will be given to reduceSF */
588     for (i=0; i<nranks; i++) {
589       PetscInt rootbase = rdisp[i]; /* root offset at this root rank */
590       PetscInt leafbase = Ai[roffset[i]]; /* leaf base */
591       PetscInt nz       = Ai[roffset[i+1]] - leafbase; /* I will send nz nonzeros to this root rank */
592       for (PetscInt k=0; k<nz; k++) {
593         iremote[leafbase+k].rank  = ranks[i];
594         iremote[leafbase+k].index = rootbase + k;
595       }
596     }
597     PetscCall(PetscSFCreate(comm,&reduceSF));
598     PetscCall(PetscSFSetGraph(reduceSF,nroots2,nleaves2,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER));
599     PetscCall(PetscFree2(sdisp,rdisp));
600 
601     /* Reduce Aa, Ajg to abuf and jbuf */
602 
603     /* If A uses local col ids, convert them to global ones before sending */
604     MatColIdxKokkosView Ajg;
605     if (local) {
606       Ajg = MatColIdxKokkosView("j",Annz);
607       const MatColIdxKokkosView& Aj = akok->j_dual.view_device();
608       Kokkos::parallel_for(Annz,KOKKOS_LAMBDA(const PetscInt i) {Ajg(i) = l2g(Aj(i));});
609     } else {
610       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
611     }
612 
613     MatColIdxKokkosView   jbuf("jbuf",Cnnz);
614     abuf = MatScalarKokkosView("abuf",Cnnz);
615     PetscCallMPI(PetscSFReduceBegin(reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE));
616     PetscCallMPI(PetscSFReduceEnd  (reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE));
617     PetscCallMPI(PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
618     PetscCallMPI(PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE));
619 
620     /* Copy data from abuf, jbuf to Ca, Cj */
621     MatRowMapKokkosView    Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),Ci_h); /* Ci is an alias of Ci_h if no device */
622     MatColIdxKokkosView    Cj("j",Cnnz);
623     MatScalarKokkosView    Ca("a",Cnnz);
624 
625     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
626       PetscInt i   = t.league_rank();
627       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
628       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
629       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {
630         Ca(dst+k) = abuf(src+k);
631         Cj(dst+k) = jbuf(src+k);
632       });
633     });
634 
635     /* Build C with Ca, Ci, Cj */
636     C    = KokkosCsrMatrix("csrmat",Cm,N,Cnnz,Ca,Ci,Cj);
637     PetscCall(PetscFree2(srowlens,reqs));
638   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
639   PetscFunctionReturn(0);
640 }
641 
642 /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
643 
644   Input Parameters:
645 +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
646 .  reuse    - indicate whether the matrix has called this function before
647 .  csrmat   - the KokkosCsrMatrix, of size m,N
648 -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
649               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
650               entry is 5, then Cdstart[i] = 3.
651 
652   Output Parameters:
653 +  C        - the updated MATMPIAIJKOKKOS matrix
654 -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
655 
656   Notes:
657    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
658  */
659 static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C,MatReuse reuse,const KokkosCsrMatrix& csrmat,MatRowMapKokkosView& Cdstart)
660 {
661   const MatScalarKokkosView&      Ca = csrmat.values;
662   const ConstMatRowMapKokkosView& Ci = csrmat.graph.row_map;
663   PetscInt                        m,n,N;
664 
665   PetscFunctionBegin;
666   PetscCall(MatGetLocalSize(C,&m,&n));
667   PetscCall(MatGetSize(C,NULL,&N));
668 
669   if (reuse == MAT_REUSE_MATRIX) {
670     Mat_MPIAIJ                  *mpiaij = static_cast<Mat_MPIAIJ*>(C->data);
671     Mat_SeqAIJKokkos            *akok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr);
672     Mat_SeqAIJKokkos            *bkok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->B->spptr);
673     const MatScalarKokkosView&  Cda = akok->a_dual.view_device(),Coa = bkok->a_dual.view_device();
674     const MatRowMapKokkosView&  Cdi = akok->i_dual.view_device(),Coi = bkok->i_dual.view_device();
675 
676     /* Fill 'a' of Cd and Co on device */
677     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
678       PetscInt i       = t.league_rank(); /* row i */
679       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
680       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
681       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
682       PetscInt cdend   = cdstart + cdlen;
683       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
684       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
685         if (k < cdstart) {  /* k in [0, cdstart) */
686           Coa(Coi(i)+k) = Ca(Ci(i)+k);
687         } else if (k < cdend) { /* k in [cdstart, cdend) */
688           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
689         } else { /* k in [cdend, clen) */
690           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
691         }
692       });
693     });
694 
695     akok->a_dual.modify_device();
696     bkok->a_dual.modify_device();
697   } else if (reuse == MAT_INITIAL_MATRIX) {
698     Mat                         Cd,Co;
699     const MatColIdxKokkosView&  Cj = csrmat.graph.entries;
700     MatRowMapKokkosDualView     Cdi_dual("i",m+1),Coi_dual("i",m+1);
701     MatRowMapKokkosView         Cdi = Cdi_dual.view_device(),Coi = Coi_dual.view_device();
702     PetscInt                    cstart,cend;
703 
704     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
705        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
706        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
707        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
708        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
709      */
710     Cdstart = MatRowMapKokkosView("Cdstart",m);
711     PetscCall(PetscLayoutGetRange(C->cmap,&cstart,&cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
712 
713     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
714       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
715      */
716     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, 1),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
717       Kokkos::single(Kokkos::PerTeam(t), [=] () { /* Only one thread works in a team */
718         PetscInt i = t.league_rank(); /* row i */
719         PetscInt j,first,count,step;
720 
721         if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
722           Cdi(0) = 0;
723           Coi(0) = 0;
724         }
725 
726         /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
727           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
728         */
729         count = Ci(i+1)-Ci(i);
730         first = Ci(i);
731         while (count > 0) {
732           j    = first;
733           step = count / 2;
734           j   += step;
735           if (Cj(j) < cstart) {
736             first  = ++j;
737             count -= step + 1;
738           } else count = step;
739         }
740         Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
741 
742         /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
743         count = Ci(i+1) - first;
744         while (count > 0) {
745           j    = first;
746           step = count / 2;
747           j   += step;
748           if (Cj(j) < cend) {
749             first  = ++j;
750             count -= step + 1;
751           } else count = step;
752         }
753         Cdi(i+1) = first - (Ci(i)+Cdstart(i)); /* 'first' is the while-loop's output */
754         Coi(i+1) = (Ci(i+1)-Ci(i)) - Cdi(i+1); /* Co's row len = C's row len - Cd's row len */
755       });
756     });
757 
758     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
759     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
760       update += Cdi(i);
761       if (final) Cdi(i) = update;
762     });
763     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
764       update += Coi(i);
765       if (final) Coi(i) = update;
766     });
767 
768     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
769        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
770     */
771     Cdi_dual.modify_device();
772     Coi_dual.modify_device();
773     Cdi_dual.sync_host();
774     Coi_dual.sync_host();
775     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
776     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
777 
778     /* With nnz, allocate a, j for Cd and Co */
779     MatColIdxKokkosDualView Cdj_dual("j",Cd_nnz),Coj_dual("j",Co_nnz);
780     MatScalarKokkosDualView Cda_dual("a",Cd_nnz),Coa_dual("a",Co_nnz);
781 
782     /* Fill a, j of Cd and Co on device */
783     MatColIdxKokkosView     Cdj = Cdj_dual.view_device(),Coj = Coj_dual.view_device();
784     MatScalarKokkosView     Cda = Cda_dual.view_device(),Coa = Coa_dual.view_device();
785 
786     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
787       PetscInt i       = t.league_rank(); /* row i */
788       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
789       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
790       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
791       PetscInt cdend   = cdstart + cdlen;
792       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
793       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
794         if (k < cdstart) { /* k in [0, cdstart) */
795           Coa(Coi(i)+k) = Ca(Ci(i)+k);
796           Coj(Coi(i)+k) = Cj(Ci(i)+k);
797         } else if (k < cdend) { /* k in [cdstart, cdend) */
798           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
799           Cdj(Cdi(i)+(k-cdstart)) = Cj(Ci(i)+k) - cstart; /* Use local col ids in Cdj */
800         } else { /* k in [cdend, clen) */
801           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
802           Coj(Coi(i)+k-cdlen) = Cj(Ci(i)+k);
803         }
804       });
805     });
806 
807     Cdj_dual.modify_device();
808     Cda_dual.modify_device();
809     Coj_dual.modify_device();
810     Coa_dual.modify_device();
811     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
812     auto cdkok = new Mat_SeqAIJKokkos(m,n,Cd_nnz,Cdi_dual,Cdj_dual,Cda_dual);
813     auto cokok = new Mat_SeqAIJKokkos(m,N,Co_nnz,Coi_dual,Coj_dual,Coa_dual);
814     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cdkok,&Cd));
815     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cokok,&Co));
816     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C,Cd,Co)); /* Coj will be converted to local ids within */
817   }
818   PetscFunctionReturn(0);
819 }
820 
821 /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.
822 
823   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
824 
825   Input Parameters:
826 .  C        - the SEQAIJKOKKOS matrix whose col ids are to be compacted (e.g., P_other). The compaction itself is done
827               on host by MatSeqAIJCompactOutExtraColumns_SeqAIJ().
828 
829   Output Parameters:
830 +  C        - the matrix with compacted col ids; its j array is synced to device and the col size of the csrmat
831               in C->spptr is updated to C->cmap->n
832 -  l2g      - a newly allocated Kokkos view of length C->cmap->n, holding the local to global mapping of C's cols
833               (a copy of the indices of the mapping returned by the host routine)
834 
835   Notes: the input matrix's col ids and col size will be changed. Any view previously held in l2g is replaced.
836          An error (PETSC_ERR_PLIB) is raised if the mapping size differs from C->cmap->n.
837 */
838 static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C,MatColIdxKokkosView& l2g)
839 {
840   Mat_SeqAIJKokkos       *ckok;
841   ISLocalToGlobalMapping l2gmap;
842   const PetscInt         *garray;
843   PetscInt               sz;
844 
845   PetscFunctionBegin;
846   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
847   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C,&l2gmap));
848   ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
849   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
850   ckok->j_dual.sync_device();
851   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
852 
853   /* Build l2g -- the local to global mapping of C's cols */
854   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap,&garray));
855   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap,&sz));
856   PetscCheck(C->cmap->n == sz,PETSC_COMM_SELF,PETSC_ERR_PLIB,"matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n,sz);
857 
858   ConstMatColIdxKokkosViewHost tmp(garray,sz); /* wrap garray as an unmanaged host view; no copy yet */
859   l2g = MatColIdxKokkosView("l2g",sz);
860   Kokkos::deep_copy(l2g,tmp); /* l2g owns its own copy, so garray can be restored and l2gmap destroyed below */
861 
862   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap,&garray));
863   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
864   PetscFunctionReturn(0);
865 }
866 
867 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
868 
869   Input Parameters:
870 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product
                (only its fill and api_user fields are read here).
871 .  A        - an MPIAIJKOKKOS matrix
872 .  B        - an MPIAIJKOKKOS matrix
873 -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
874 
875   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
           It is computed as C1_global + C2_global, where C1 = Ad * B_local and C2 = Ao * B_other (Ad, Ao are the diag and
           offdiag blocks of A). The intermediate product matrices are kept in mm->C1 and mm->C2 for the numeric phase.
876 */
877 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product,Mat A,Mat B,MatMatStruct_AB *mm)
878 {
879   Mat_MPIAIJ                  *a = static_cast<Mat_MPIAIJ*>(A->data);
880   Mat                         Ad = a->A,Ao = a->B; /* diag and offdiag of A */
881   IS                          glob = NULL;
882   const PetscInt              *garray;
883   PetscInt                    N = B->cmap->N,sz;
884   ConstMatColIdxKokkosView    l2g1; /* two temp maps mapping local col ids to global ones */
885   MatColIdxKokkosView         l2g2;
886   Mat                         C1,C2; /* intermediate matrices */
887 
888   PetscFunctionBegin;
889   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
890   PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&mm->B_local));
891   PetscCall(MatProductCreate(Ad,mm->B_local,NULL,&C1));
892   PetscCall(MatProductSetType(C1,MATPRODUCT_AB));
893   PetscCall(MatProductSetFill(C1,product->fill));
894   C1->product->api_user = product->api_user;
895   PetscCall(MatProductSetFromOptions(C1));
896   PetscUseTypeMethod(C1,productsymbolic);
897 
898   PetscCall(ISGetIndices(glob,&garray));
899   PetscCall(ISGetSize(glob,&sz));
900   const auto& tmp  = ConstMatColIdxKokkosViewHost(garray,sz); /* wrap garray as a view */
901   l2g1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
902   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g1,mm->C1_global));
903 
904   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
905   PetscCall(MatSeqAIJKokkosBcast(mm->B_local,MAT_INITIAL_MATRIX,N,l2g1,a->Mvctx,mm->sf,mm->abuf,mm->rows,mm->rowoffset,mm->B_other));
906 
907   /* Compact B_other to use local ids as we guess KK spgemm is more memory scalable with that; We could skip the compaction to simplify code */
908   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other,l2g2));
909   PetscCall(MatProductCreate(Ao,mm->B_other,NULL,&C2));
910   PetscCall(MatProductSetType(C2,MATPRODUCT_AB));
911   PetscCall(MatProductSetFill(C2,product->fill));
912   C2->product->api_user = product->api_user;
913   PetscCall(MatProductSetFromOptions(C2));
914   PetscUseTypeMethod(C2,productsymbolic);
915   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2,N,l2g2,mm->C2_global));
916 
917   /* C = C1 + C2.  We actually use their global col ids versions in adding */
918   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
919   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
920   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
921   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
922 
923   mm->C1 = C1;
924   mm->C2 = C2;
925   PetscCall(ISRestoreIndices(glob,&garray)); /* deferred until here because l2g1 may alias garray (see above) */
926   PetscCall(ISDestroy(&glob));
927   PetscFunctionReturn(0);
928 }
929 
930 /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
931 
932   Input Parameters:
933 +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product
                (only its fill and api_user fields are read here).
934 .  A        - an MPIAIJKOKKOS matrix
935 .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
936 .  localB   - Does B use local col ids? If false, then B is already in global col ids.
937 .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
938 .  l2g      - If localB, then l2g maps B's local col ids to global ones.
939 -  mm       - a struct used to stash intermediate data in AtB
940 
941   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
           It is computed as C1_global + C2_global, where C1 = Ad^t * B is local and C2 = Ao^t * B is passed through
           MatSeqAIJKokkosReduce() with A's Mvctx to form C2_global. The intermediate product matrices are kept in mm->C1 and mm->C2.
942 */
943 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product,Mat A,Mat B,PetscBool localB,PetscInt N,const ConstMatColIdxKokkosView& l2g,MatMatStruct_AtB *mm)
944 {
945   Mat_MPIAIJ             *a = static_cast<Mat_MPIAIJ*>(A->data);
946   Mat                    Ad = a->A,Ao = a->B; /* diag and offdiag of A */
947   Mat                    C1,C2; /* intermediate matrices */
948 
949   PetscFunctionBegin;
950   /* C1 = Ad^t * B */
951   PetscCall(MatProductCreate(Ad,B,NULL,&C1));
952   PetscCall(MatProductSetType(C1,MATPRODUCT_AtB));
953   PetscCall(MatProductSetFill(C1,product->fill));
954   C1->product->api_user = product->api_user;
955   PetscCall(MatProductSetFromOptions(C1));
956   PetscUseTypeMethod(C1,productsymbolic);
957 
958   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g,mm->C1_global));
959   else mm->C1_global = static_cast<Mat_SeqAIJKokkos*>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
960 
961   /* C2 = Ao^t * B */
962   PetscCall(MatProductCreate(Ao,B,NULL,&C2));
963   PetscCall(MatProductSetType(C2,MATPRODUCT_AtB));
964   PetscCall(MatProductSetFill(C2,product->fill));
965   C2->product->api_user = product->api_user;
966   PetscCall(MatProductSetFromOptions(C2));
967   PetscUseTypeMethod(C2,productsymbolic);
968 
      /* Form C2_global from C2 (see MatSeqAIJKokkosReduce); the SF, buffer and row offsets are stashed in mm for reuse in the numeric phase */
969   PetscCall(MatSeqAIJKokkosReduce(C2,MAT_INITIAL_MATRIX,localB,N,l2g,a->Mvctx,mm->sf,mm->abuf,mm->srcrowoffset,mm->dstrowoffset,mm->C2_global));
970 
971   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
972   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
973   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
974   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
975   mm->C1 = C1;
976   mm->C2 = C2;
977   PetscFunctionReturn(0);
978 }
979 
980 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
981 {
982   Mat_Product                   *product = C->product;
983   MatProductType                ptype;
984   MatProductData_MPIAIJKokkos   *mmdata;
985   MatMatStruct                  *mm = NULL;
986   MatMatStruct_AB               *ab;
987   MatMatStruct_AtB              *atb;
988   Mat                           A,B,Ad,Ao,Bd,Bo;
989   const MatScalarType           one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
990 
991   PetscFunctionBegin;
992   MatCheckProduct(C,1);
993   mmdata = static_cast<MatProductData_MPIAIJKokkos*>(product->data);
994   ptype  = product->type;
995   A      = product->A;
996   B      = product->B;
997   Ad     = static_cast<Mat_MPIAIJ*>(A->data)->A;
998   Ao     = static_cast<Mat_MPIAIJ*>(A->data)->B;
999   Bd     = static_cast<Mat_MPIAIJ*>(B->data)->A;
1000   Bo     = static_cast<Mat_MPIAIJ*>(B->data)->B;
1001 
1002   if (mmdata->reusesym) { /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1003     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1004     ab  = mmdata->mmAB;
1005     atb = mmdata->mmAtB;
1006     if (ab) {
1007       static_cast<MatProductData_SeqAIJKokkos*>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1008       static_cast<MatProductData_SeqAIJKokkos*>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1009     }
1010     if (atb) {
1011       static_cast<MatProductData_SeqAIJKokkos*>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1012       static_cast<MatProductData_SeqAIJKokkos*>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1013     }
1014     PetscFunctionReturn(0);
1015   }
1016 
1017   if (ptype == MATPRODUCT_AB) {
1018     ab   = mmdata->mmAB;
1019     /* C1 = Ad * B_local */
1020     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AB");
1021     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local));
1022     PetscCheck(ab->C1->product->B == ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1023     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad,NULL,NULL,ab->C1));
1024     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1025     PetscCall(MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1026                                  ab->abuf,ab->rows,ab->rowoffset,ab->B_other));
1027     /* C2 = Ao * B_other */
1028     PetscCheck(ab->C2->product->B == ab->B_other,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1029     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao,NULL,NULL,ab->C2));
1030     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1031     /* C = C1_global + C2_global */
1032     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1033     mm = static_cast<MatMatStruct*>(ab);
1034   } else if (ptype == MATPRODUCT_AtB) {
1035     atb  = mmdata->mmAtB;
1036     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AtB");
1037     /* C1 = Ad^t * B_local */
1038     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&atb->B_local));
1039     PetscCheck(atb->C1->product->B == atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1040     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad,NULL,NULL,atb->C1));
1041     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1042 
1043     /* C2 = Ao^t * B_local */
1044     PetscCheck(atb->C2->product->B == atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1045     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao,NULL,NULL,atb->C2));
1046     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1047     /* Form C2_global */
1048     PetscCall(MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_TRUE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1049                                   atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global));
1050     /* C = C1_global + C2_global */
1051     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1052     mm = static_cast<MatMatStruct*>(atb);
1053   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1054     ab   = mmdata->mmAB;
1055     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local));
1056 
1057     /* ab->C1 = Ad * B_local */
1058     PetscCheck(ab->C1->product->B == ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1059     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad,NULL,NULL,ab->C1));
1060     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1061     PetscCall(MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1062                                  ab->abuf,ab->rows,ab->rowoffset,ab->B_other));
1063     /* ab->C2 = Ao * B_other */
1064     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao,NULL,NULL,ab->C2));
1065     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1066     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1067 
1068     /* atb->C1 = Bd^t * ab->C_petsc */
1069     atb  = mmdata->mmAtB;
1070     PetscCheck(atb->C1->product->B == ab->C_petsc,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1071     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd,NULL,NULL,atb->C1));
1072     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1073     /* atb->C2 = Bo^t * ab->C_petsc */
1074     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo,NULL,NULL,atb->C2));
1075     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1076     PetscCall(MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_FALSE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1077                                   atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global));
1078     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1079     mm = static_cast<MatMatStruct*>(atb);
1080   }
1081   /* Split C_global to form C */
1082   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_REUSE_MATRIX,mm->C_global,mm->Cdstart));
1083   PetscFunctionReturn(0);
1084 }
1085 
/* MatProductSymbolic_MPIAIJKokkos - Symbolic phase of AB, AtB and PtAP products for MATMPIAIJKOKKOS

  Input Parameter:
.  C - the product matrix; C->product carries A, B and the product type, and C->product->data must be empty on entry

  Notes:
   Sets C's sizes from A and B according to the product type and gives C the same type as A.
   PtAP is computed as B^t*(A*B), with the intermediate A*B wrapped as a sequential matrix in ab->C_petsc.
   The resulting C_global (global col ids) is split into C's diag and offdiag blocks with MAT_INITIAL_MATRIX.
   On success C->product->data, C->product->destroy and C->ops->productnumeric are installed.
   Any other product type raises PETSC_ERR_PLIB.
*/
1086 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1087 {
1088   Mat                         A,B;
1089   Mat_Product                 *product = C->product;
1090   MatProductType              ptype;
1091   MatProductData_MPIAIJKokkos *mmdata;
1092   MatMatStruct                *mm = NULL;
1093   IS                          glob = NULL;
1094   const PetscInt              *garray;
1095   PetscInt                    m,n,M,N,sz;
1096   ConstMatColIdxKokkosView    l2g; /* map local col ids to global ones */
1097 
1098   PetscFunctionBegin;
1099   MatCheckProduct(C,1);
1100   PetscCheck(!product->data,PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Product data not empty");
1101   ptype = product->type;
1102   A     = product->A;
1103   B     = product->B;
1104 
       /* Local (m,n) and global (M,N) sizes of C for each supported product type */
1105   switch (ptype) {
1106     case MATPRODUCT_AB:   m = A->rmap->n; n = B->cmap->n; M = A->rmap->N; N = B->cmap->N; break;
1107     case MATPRODUCT_AtB:  m = A->cmap->n; n = B->cmap->n; M = A->cmap->N; N = B->cmap->N; break;
1108     case MATPRODUCT_PtAP: m = B->cmap->n; n = B->cmap->n; M = B->cmap->N; N = B->cmap->N; break; /* BtAB */
1109     default: SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[ptype]);
1110   }
1111 
1112   PetscCall(MatSetSizes(C,m,n,M,N));
1113   PetscCall(PetscLayoutSetUp(C->rmap));
1114   PetscCall(PetscLayoutSetUp(C->cmap));
1115   PetscCall(MatSetType(C,((PetscObject)A)->type_name));
1116 
1117   mmdata           = new MatProductData_MPIAIJKokkos();
1118   mmdata->reusesym = product->api_user; /* the numeric routine skips its work once when this flag is set */
1119 
1120   if (ptype == MATPRODUCT_AB) {
1121     mmdata->mmAB = new MatMatStruct_AB();
1122     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,mmdata->mmAB));
1123     mm   = static_cast<MatMatStruct*>(mmdata->mmAB);
1124   } else if (ptype == MATPRODUCT_AtB) {
1125     mmdata->mmAtB = new MatMatStruct_AtB();
1126     auto atb      = mmdata->mmAtB;
1127     PetscCall(MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&atb->B_local));
1128     PetscCall(ISGetIndices(glob,&garray));
1129     PetscCall(ISGetSize(glob,&sz));
1130     l2g  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatColIdxKokkosViewHost(garray,sz));
1131     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product,A,atb->B_local,PETSC_TRUE,N,l2g,atb));
1132     PetscCall(ISRestoreIndices(glob,&garray)); /* restored only after l2g has been used above */
1133     PetscCall(ISDestroy(&glob));
1134     mm   = static_cast<MatMatStruct*>(atb);
1135   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1136     mmdata->mmAB  = new MatMatStruct_AB(); /* tmp=A*B */
1137     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1138     auto ab       = mmdata->mmAB;
1139     auto atb      = mmdata->mmAtB;
1140     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,ab));
1141     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1142     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,tmp,&ab->C_petsc));
1143     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product,B,ab->C_petsc,PETSC_FALSE,N,l2g/*not used*/,atb));
1144     mm   = static_cast<MatMatStruct*>(atb);
1145   }
1146   /* Split the C_global into petsc A, B format */
1147   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_INITIAL_MATRIX,mm->C_global,mm->Cdstart));
1148   C->product->data        = mmdata;
1149   C->product->destroy     = MatProductDataDestroy_MPIAIJKokkos;
1150   C->ops->productnumeric  = MatProductNumeric_MPIAIJKokkos;
1151   PetscFunctionReturn(0);
1152 }
1153 
/* MatProductSetFromOptions_MPIAIJKokkos - Choose the symbolic routine for a product involving MATMPIAIJKOKKOS

  Input Parameter:
.  mat - the product matrix, with mat->product already set up

  Notes:
   MatProductSymbolic_MPIAIJKokkos is selected for AB, AtB and PtAP when neither A nor B is bound to the CPU,
   B has the same type as A, and the user has not asked for the CPU code through one of the options
   -matmatmult_backend_cpu, -mattransposematmult_backend_cpu, -matptap_backend_cpu (when product->api_user is set)
   or -mat_product_algorithm_backend_cpu (otherwise).
   If mat->ops->productsymbolic is still unset afterwards, MatProductSetFromOptions_MPIAIJ() is called as a fallback.
*/
1154 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1155 {
1156   Mat_Product    *product = mat->product;
1157   PetscBool      match = PETSC_FALSE;
1158   PetscBool      usecpu = PETSC_FALSE;
1159 
1160   PetscFunctionBegin;
1161   MatCheckProduct(mat,1);
1162   if (!product->A->boundtocpu && !product->B->boundtocpu) {
1163     PetscCall(PetscObjectTypeCompare((PetscObject)product->B,((PetscObject)product->A)->type_name,&match));
1164   }
1165   if (match) { /* we can always fallback to the CPU if requested */
1166     switch (product->type) {
1167     case MATPRODUCT_AB:
1168       if (product->api_user) {
1169         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatMatMult","Mat");
1170         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL));
1171         PetscOptionsEnd();
1172       } else {
1173         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AB","Mat");
1174         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL));
1175         PetscOptionsEnd();
1176       }
1177       break;
1178     case MATPRODUCT_AtB:
1179       if (product->api_user) {
1180         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatTransposeMatMult","Mat");
1181         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL));
1182         PetscOptionsEnd();
1183       } else {
1184         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AtB","Mat");
1185         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL));
1186         PetscOptionsEnd();
1187       }
1188       break;
1189     case MATPRODUCT_PtAP:
1190       if (product->api_user) {
1191         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatPtAP","Mat");
1192         PetscCall(PetscOptionsBool("-matptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL));
1193         PetscOptionsEnd();
1194       } else {
1195         PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_PtAP","Mat");
1196         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL));
1197         PetscOptionsEnd();
1198       }
1199       break;
1200     default:
1201       break;
1202     }
1203     match = (PetscBool)!usecpu; /* a request for the CPU code cancels the match */
1204   }
1205   if (match) {
1206     switch (product->type) {
1207     case MATPRODUCT_AB:
1208     case MATPRODUCT_AtB:
1209     case MATPRODUCT_PtAP:
1210       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1211       break;
1212     default:
1213       break;
1214     }
1215   }
1216   /* fallback to MPIAIJ ops */
1217   if (!mat->ops->productsymbolic) {
1218     PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1219   }
1220   PetscFunctionReturn(0);
1221 }
1222 
1223 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1224 {
1225   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ*)mat->data;
1226   Mat_MPIAIJKokkos *mpikok;
1227 
1228   PetscFunctionBegin;
1229   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat,coo_n,coo_i,coo_j));
1230   mat->preallocated = PETSC_TRUE;
1231   PetscCall(MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY));
1232   PetscCall(MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY));
1233   PetscCall(MatZeroEntries(mat));
1234   mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
1235   delete mpikok;
1236   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1237   PetscFunctionReturn(0);
1238 }
1239 
1240 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat,const PetscScalar v[],InsertMode imode)
1241 {
1242   Mat_MPIAIJ                     *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
1243   Mat_MPIAIJKokkos               *mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
1244   Mat                            A = mpiaij->A,B = mpiaij->B;
1245   PetscCount                     Annz = mpiaij->Annz,Annz2 = mpiaij->Annz2,Bnnz = mpiaij->Bnnz,Bnnz2 = mpiaij->Bnnz2;
1246   MatScalarKokkosView            Aa,Ba;
1247   MatScalarKokkosView            v1;
1248   MatScalarKokkosView&           vsend = mpikok->sendbuf_d;
1249   const MatScalarKokkosView&     v2 = mpikok->recvbuf_d;
1250   const PetscCountKokkosView&    Ajmap1 = mpikok->Ajmap1_d,Ajmap2 = mpikok->Ajmap2_d,Aimap2 = mpikok->Aimap2_d;
1251   const PetscCountKokkosView&    Bjmap1 = mpikok->Bjmap1_d,Bjmap2 = mpikok->Bjmap2_d,Bimap2 = mpikok->Bimap2_d;
1252   const PetscCountKokkosView&    Aperm1 = mpikok->Aperm1_d,Aperm2 = mpikok->Aperm2_d,Bperm1 = mpikok->Bperm1_d,Bperm2 = mpikok->Bperm2_d;
1253   const PetscCountKokkosView&    Cperm1 = mpikok->Cperm1_d;
1254   PetscMemType                   memtype;
1255 
1256   PetscFunctionBegin;
1257   PetscCall(PetscGetMemType(v,&memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1258   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */
1259     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),MatScalarKokkosViewHost((PetscScalar*)v,mpiaij->coo_n));
1260   } else {
1261     v1 = MatScalarKokkosView((PetscScalar*)v,mpiaij->coo_n); /* Directly use v[]'s memory */
1262   }
1263 
1264   if (imode == INSERT_VALUES) {
1265     PetscCall(MatSeqAIJGetKokkosViewWrite(A,&Aa)); /* write matrix values */
1266     PetscCall(MatSeqAIJGetKokkosViewWrite(B,&Ba));
1267   } else {
1268     PetscCall(MatSeqAIJGetKokkosView(A,&Aa)); /* read & write matrix values */
1269     PetscCall(MatSeqAIJGetKokkosView(B,&Ba));
1270   }
1271 
1272   /* Pack entries to be sent to remote */
1273   Kokkos::parallel_for(vsend.extent(0),KOKKOS_LAMBDA(const PetscCount i) {vsend(i) = v1(Cperm1(i));});
1274 
1275   /* Send remote entries to their owner and overlap the communication with local computation */
1276   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf,MPIU_SCALAR,PETSC_MEMTYPE_KOKKOS,vsend.data(),PETSC_MEMTYPE_KOKKOS,v2.data(),MPI_REPLACE));
1277   /* Add local entries to A and B in one kernel */
1278   Kokkos::parallel_for(Annz+Bnnz,KOKKOS_LAMBDA(PetscCount i) {
1279     PetscScalar sum = 0.0;
1280     if (i<Annz) {
1281       for (PetscCount k=Ajmap1(i); k<Ajmap1(i+1); k++) sum += v1(Aperm1(k));
1282       Aa(i) = (imode == INSERT_VALUES? 0.0 : Aa(i)) + sum;
1283     } else {
1284       i -= Annz;
1285       for (PetscCount k=Bjmap1(i); k<Bjmap1(i+1); k++) sum += v1(Bperm1(k));
1286       Ba(i) = (imode == INSERT_VALUES? 0.0 : Ba(i)) + sum;
1287     }
1288   });
1289   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf,MPIU_SCALAR,vsend.data(),v2.data(),MPI_REPLACE));
1290 
1291   /* Add received remote entries to A and B in one kernel */
1292   Kokkos::parallel_for(Annz2+Bnnz2,KOKKOS_LAMBDA(PetscCount i) {
1293     if (i < Annz2) {
1294       for (PetscCount k=Ajmap2(i); k<Ajmap2(i+1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1295     } else {
1296       i -= Annz2;
1297       for (PetscCount k=Bjmap2(i); k<Bjmap2(i+1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1298     }
1299   });
1300 
1301   if (imode == INSERT_VALUES) {
1302     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A,&Aa)); /* Increase A & B's state etc. */
1303     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B,&Ba));
1304   } else {
1305     PetscCall(MatSeqAIJRestoreKokkosView(A,&Aa));
1306     PetscCall(MatSeqAIJRestoreKokkosView(B,&Ba));
1307   }
1308   PetscFunctionReturn(0);
1309 }
1310 
1311 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1312 {
1313   Mat_MPIAIJ         *mpiaij = (Mat_MPIAIJ*)A->data;
1314 
1315   PetscFunctionBegin;
1316   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",NULL));
1317   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",NULL));
1318   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatSetPreallocationCOO_C",   NULL));
1319   PetscCall(PetscObjectComposeFunction((PetscObject)A,"MatSetValuesCOO_C",          NULL));
1320   delete (Mat_MPIAIJKokkos*)mpiaij->spptr;
1321   PetscCall(MatDestroy_MPIAIJ(A));
1322   PetscFunctionReturn(0);
1323 }
1324 
1325 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat* newmat)
1326 {
1327   Mat                B;
1328   Mat_MPIAIJ         *a;
1329 
1330   PetscFunctionBegin;
1331   if (reuse == MAT_INITIAL_MATRIX) {
1332     PetscCall(MatDuplicate(A,MAT_COPY_VALUES,newmat));
1333   } else if (reuse == MAT_REUSE_MATRIX) {
1334     PetscCall(MatCopy(A,*newmat,SAME_NONZERO_PATTERN));
1335   }
1336   B = *newmat;
1337 
1338   B->boundtocpu = PETSC_FALSE;
1339   PetscCall(PetscFree(B->defaultvectype));
1340   PetscCall(PetscStrallocpy(VECKOKKOS,&B->defaultvectype));
1341   PetscCall(PetscObjectChangeTypeName((PetscObject)B,MATMPIAIJKOKKOS));
1342 
1343   a = static_cast<Mat_MPIAIJ*>(A->data);
1344   if (a->A) PetscCall(MatSetType(a->A,MATSEQAIJKOKKOS));
1345   if (a->B) PetscCall(MatSetType(a->B,MATSEQAIJKOKKOS));
1346   if (a->lvec) PetscCall(VecSetType(a->lvec,VECSEQKOKKOS));
1347 
1348   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1349   B->ops->mult                  = MatMult_MPIAIJKokkos;
1350   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1351   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1352   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1353   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1354 
1355   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJSetPreallocation_C",MatMPIAIJSetPreallocation_MPIAIJKokkos));
1356   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJGetLocalMatMerge_C",MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1357   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatSetPreallocationCOO_C",   MatSetPreallocationCOO_MPIAIJKokkos));
1358   PetscCall(PetscObjectComposeFunction((PetscObject)B,"MatSetValuesCOO_C",          MatSetValuesCOO_MPIAIJKokkos));
1359   PetscFunctionReturn(0);
1360 }
/*MC
   MATMPIAIJKOKKOS - MATAIJKOKKOS = "(mpi)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos

   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types

   Options Database Keys:
.  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()

  Level: beginner

.seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`
M*/
1373 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1374 {
1375   PetscFunctionBegin;
1376   PetscCall(PetscKokkosInitializeCheck());
1377   PetscCall(MatCreate_MPIAIJ(A));
1378   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&A));
1379   PetscFunctionReturn(0);
1380 }
1381 
/*@C
   MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.

   Collective

   Input Parameters:
+  comm - MPI communicator; on a communicator of size one a sequential Kokkos matrix is created and only d_nz/d_nnz are used
.  m - number of local rows
.  n - number of local columns
.  M - number of global rows
.  N - number of global columns
.  d_nz - number of nonzeros per row in the diagonal portion of the local rows (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local rows (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           (possibly different for each row) or NULL

   Output Parameter:
.  A - the matrix

   It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
   MatXXXXSetPreallocation() paradigm instead of this routine directly.
   [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]

   Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored

   The AIJ format (also called the Yale sparse matrix format or
   compressed row storage), is fully compatible with standard Fortran 77
   storage.  That is, the stored row and column indices can begin at
   either one (as in Fortran) or zero.  See the users' manual for details.

   Specify the preallocated storage with either d_nz or d_nnz (not both), and likewise with either o_nz or o_nnz.
   Set d_nz=PETSC_DEFAULT and d_nnz=NULL (and likewise for o_nz and o_nnz) for PETSc to control dynamic memory
   allocation.  For large problems you MUST preallocate memory or you
   will get TERRIBLE performance, see the users' manual chapter on matrices.

   By default, this format uses inodes (identical nodes) when possible, to
   improve numerical efficiency of matrix-vector products and solves. We
   search for consecutive rows with the same nonzero structure, thereby
   reusing matrix information to achieve increased efficiency.

   Level: intermediate

.seealso: `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
@*/
1429 PetscErrorCode  MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat *A)
1430 {
1431   PetscMPIInt    size;
1432 
1433   PetscFunctionBegin;
1434   PetscCall(MatCreate(comm,A));
1435   PetscCall(MatSetSizes(*A,m,n,M,N));
1436   PetscCallMPI(MPI_Comm_size(comm,&size));
1437   if (size > 1) {
1438     PetscCall(MatSetType(*A,MATMPIAIJKOKKOS));
1439     PetscCall(MatMPIAIJSetPreallocation(*A,d_nz,d_nnz,o_nz,o_nnz));
1440   } else {
1441     PetscCall(MatSetType(*A,MATSEQAIJKOKKOS));
1442     PetscCall(MatSeqAIJSetPreallocation(*A,d_nz,d_nnz));
1443   }
1444   PetscFunctionReturn(0);
1445 }
1446 
1447 // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1448 PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1449 {
1450   PetscMPIInt                size,rank;
1451   MPI_Comm                   comm;
1452   PetscSplitCSRDataStructure d_mat=NULL;
1453 
1454   PetscFunctionBegin;
1455   PetscCall(PetscObjectGetComm((PetscObject)A,&comm));
1456   PetscCallMPI(MPI_Comm_size(comm,&size));
1457   PetscCallMPI(MPI_Comm_rank(comm,&rank));
1458   if (size == 1) {
1459     PetscCall(MatSeqAIJKokkosGetDeviceMat(A,&d_mat));
1460     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1461   } else {
1462     Mat_MPIAIJ  *aij = (Mat_MPIAIJ*)A->data;
1463     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A,&d_mat));
1464     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1465     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1466     PetscCheck(A->nooffprocentries || aij->donotstash,PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1467   }
1468   // act like MatSetValues because not called on host
1469   if (A->assembled) {
1470     if (A->was_assembled) {
1471       PetscCall(PetscInfo(A,"Assemble more than once already\n"));
1472     }
1473     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1474   } else {
1475     PetscCall(PetscInfo(A,"Warning !assemble ??? assembled=%" PetscInt_FMT "\n",A->assembled));
1476   }
1477   if (!d_mat) {
1478     struct _n_SplitCSRMat h_mat; /* host container */
1479     Mat_SeqAIJKokkos      *aijkokA;
1480     Mat_SeqAIJ            *jaca;
1481     PetscInt              n = A->rmap->n, nnz;
1482     Mat                   Amat;
1483     PetscInt              *colmap;
1484 
1485     /* create and copy h_mat */
1486     h_mat.M = A->cmap->N; // use for debug build
1487     PetscCall(PetscInfo(A,"Create device matrix in Kokkos\n"));
1488     if (size == 1) {
1489       Amat = A;
1490       jaca = (Mat_SeqAIJ*)A->data;
1491       h_mat.rstart = 0; h_mat.rend = A->rmap->n;
1492       h_mat.cstart = 0; h_mat.cend = A->cmap->n;
1493       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1494       h_mat.offdiag.a = NULL;
1495       aijkokA = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1496     } else {
1497       Mat_MPIAIJ       *aij = (Mat_MPIAIJ*)A->data;
1498       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ*)aij->B->data;
1499       PetscInt         ii;
1500       Mat_SeqAIJKokkos *aijkokB;
1501 
1502       Amat = aij->A;
1503       aijkokA = static_cast<Mat_SeqAIJKokkos*>(aij->A->spptr);
1504       aijkokB = static_cast<Mat_SeqAIJKokkos*>(aij->B->spptr);
1505       jaca = (Mat_SeqAIJ*)aij->A->data;
1506       PetscCheck(!aij->B->cmap->n || aij->garray,comm,PETSC_ERR_PLIB,"MPIAIJ Matrix was assembled but is missing garray");
1507       PetscCheck(aij->B->rmap->n == aij->A->rmap->n,comm,PETSC_ERR_SUP,"Only support aij->B->rmap->n == aij->A->rmap->n");
1508       aij->donotstash = PETSC_TRUE;
1509       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1510       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1511       PetscCall(PetscCalloc1(A->cmap->N,&colmap));
1512       PetscCall(PetscLogObjectMemory((PetscObject)A,(A->cmap->N)*sizeof(PetscInt)));
1513       for (ii=0; ii<aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii+1;
1514       // allocate B copy data
1515       h_mat.rstart = A->rmap->rstart; h_mat.rend = A->rmap->rend;
1516       h_mat.cstart = A->cmap->rstart; h_mat.cend = A->cmap->rend;
1517       nnz = jacb->i[n];
1518       if (jacb->compressedrow.use) {
1519         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_i_k (jacb->i,n+1);
1520         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_i_k));
1521         Kokkos::deep_copy (aijkokB->i_uncompressed_d, h_i_k);
1522         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1523       } else {
1524          h_mat.offdiag.i = aijkokB->i_device_data();
1525       }
1526       h_mat.offdiag.j = aijkokB->j_device_data();
1527       h_mat.offdiag.a = aijkokB->a_device_data();
1528       {
1529         Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_colmap_k (colmap,A->cmap->N);
1530         aijkokB->colmap_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_colmap_k));
1531         Kokkos::deep_copy (aijkokB->colmap_d, h_colmap_k);
1532         h_mat.colmap = aijkokB->colmap_d.data();
1533         PetscCall(PetscFree(colmap));
1534       }
1535       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1536       h_mat.offdiag.n = n;
1537     }
1538     // allocate A copy data
1539     nnz = jaca->i[n];
1540     h_mat.diag.n = n;
1541     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1542     PetscCallMPI(MPI_Comm_rank(comm,&h_mat.rank));
1543     PetscCheck(!jaca->compressedrow.use,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A does not suppport compressed row (todo)");
1544     h_mat.diag.i = aijkokA->i_device_data();
1545     h_mat.diag.j = aijkokA->j_device_data();
1546     h_mat.diag.a = aijkokA->a_device_data();
1547     // copy pointers and metdata to device
1548     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat,&h_mat));
1549     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat,&d_mat));
1550     PetscCall(PetscInfo(A,"Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n",h_mat.diag.n, nnz));
1551   }
1552   *B = d_mat; // return it, set it in Mat, and set it up
1553   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1554   PetscFunctionReturn(0);
1555 }
1556 
1557 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1558 {
1559   Mat_SeqAIJKokkos  *aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1560 
1561   PetscFunctionBegin;
1562   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1563   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1564   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1565   else *mask = "PETSC_OFFLOAD_BOTH";
1566   PetscFunctionReturn(0);
1567 }
1568 
1569 PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1570 {
1571   PetscMPIInt  size;
1572   Mat          Ad,Ao;
1573   const char  *amask,*bmask;
1574 
1575   PetscFunctionBegin;
1576   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A),&size));
1577 
1578   if (size == 1) {
1579     PetscCall(MatSeqAIJKokkosGetOffloadMask(A,&amask));
1580     PetscCall(PetscPrintf(PETSC_COMM_SELF,"%s\n",amask));
1581   } else {
1582     Ad  = ((Mat_MPIAIJ*)A->data)->A;
1583     Ao  = ((Mat_MPIAIJ*)A->data)->B;
1584     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad,&amask));
1585     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao,&bmask));
1586     PetscCall(PetscPrintf(PETSC_COMM_SELF,"Diag : Off-diag = %s : %s\n",amask,bmask));
1587   }
1588   PetscFunctionReturn(0);
1589 }
1590