xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 076ba34a670039b3127ea98935d34fd539c66ade)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2*076ba34aSJunchao Zhang #include <petscsf.h>
38c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
4a587d139SMark #include <../src/mat/impls/aij/seq/kokkos/aijkokkosimpl.hpp>
5*076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
611d22bbfSJunchao Zhang 
78c3ff71bSJunchao Zhang PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)
88c3ff71bSJunchao Zhang {
98c3ff71bSJunchao Zhang   PetscErrorCode   ierr;
108c3ff71bSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ*)A->data;
11a587d139SMark   Mat_SeqAIJKokkos *aijkok = mpiaij->A->spptr ? static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr) : NULL;
128c3ff71bSJunchao Zhang 
138c3ff71bSJunchao Zhang   PetscFunctionBegin;
148c3ff71bSJunchao Zhang   ierr = MatAssemblyEnd_MPIAIJ(A,mode);CHKERRQ(ierr);
15a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
16a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
17a587d139SMark   }
18a587d139SMark 
198c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
208c3ff71bSJunchao Zhang }
218c3ff71bSJunchao Zhang 
228c3ff71bSJunchao Zhang PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])
238c3ff71bSJunchao Zhang {
248c3ff71bSJunchao Zhang   PetscErrorCode ierr;
258c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
268c3ff71bSJunchao Zhang 
278c3ff71bSJunchao Zhang   PetscFunctionBegin;
288c3ff71bSJunchao Zhang   ierr = PetscLayoutSetUp(mat->rmap);CHKERRQ(ierr);
298c3ff71bSJunchao Zhang   ierr = PetscLayoutSetUp(mat->cmap);CHKERRQ(ierr);
306a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
318c3ff71bSJunchao Zhang   if (d_nnz) {
326a29ce69SStefano Zampini     PetscInt i;
338c3ff71bSJunchao Zhang     for (i=0; i<mat->rmap->n; i++) {
348c3ff71bSJunchao Zhang       if (d_nnz[i] < 0) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"d_nnz cannot be less than 0: local row %D value %D",i,d_nnz[i]);
358c3ff71bSJunchao Zhang     }
368c3ff71bSJunchao Zhang   }
378c3ff71bSJunchao Zhang   if (o_nnz) {
386a29ce69SStefano Zampini     PetscInt i;
398c3ff71bSJunchao Zhang     for (i=0; i<mat->rmap->n; i++) {
408c3ff71bSJunchao Zhang       if (o_nnz[i] < 0) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"o_nnz cannot be less than 0: local row %D value %D",i,o_nnz[i]);
418c3ff71bSJunchao Zhang     }
428c3ff71bSJunchao Zhang   }
436a29ce69SStefano Zampini #endif
446a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
456a29ce69SStefano Zampini   ierr = PetscTableDestroy(&mpiaij->colmap);CHKERRQ(ierr);
466a29ce69SStefano Zampini #else
476a29ce69SStefano Zampini   ierr = PetscFree(mpiaij->colmap);CHKERRQ(ierr);
486a29ce69SStefano Zampini #endif
496a29ce69SStefano Zampini   ierr = PetscFree(mpiaij->garray);CHKERRQ(ierr);
506a29ce69SStefano Zampini   ierr = VecDestroy(&mpiaij->lvec);CHKERRQ(ierr);
516a29ce69SStefano Zampini   ierr = VecScatterDestroy(&mpiaij->Mvctx);CHKERRQ(ierr);
526a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
536a29ce69SStefano Zampini   ierr = MatDestroy(&mpiaij->B);CHKERRQ(ierr);
546a29ce69SStefano Zampini 
556a29ce69SStefano Zampini   if (!mpiaij->A) {
568c3ff71bSJunchao Zhang     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->A);CHKERRQ(ierr);
578c3ff71bSJunchao Zhang     ierr = MatSetSizes(mpiaij->A,mat->rmap->n,mat->cmap->n,mat->rmap->n,mat->cmap->n);CHKERRQ(ierr);
588c3ff71bSJunchao Zhang     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->A);CHKERRQ(ierr);
596a29ce69SStefano Zampini   }
606a29ce69SStefano Zampini   if (!mpiaij->B) {
616a29ce69SStefano Zampini     PetscMPIInt size;
6255b25c41SPierre Jolivet     ierr = MPI_Comm_size(PetscObjectComm((PetscObject)mat),&size);CHKERRMPI(ierr);
638c3ff71bSJunchao Zhang     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->B);CHKERRQ(ierr);
646a29ce69SStefano Zampini     ierr = MatSetSizes(mpiaij->B,mat->rmap->n,size > 1 ? mat->cmap->N : 0,mat->rmap->n,size > 1 ? mat->cmap->N : 0);CHKERRQ(ierr);
658c3ff71bSJunchao Zhang     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->B);CHKERRQ(ierr);
668c3ff71bSJunchao Zhang   }
676a29ce69SStefano Zampini   ierr = MatSetType(mpiaij->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
686a29ce69SStefano Zampini   ierr = MatSetType(mpiaij->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);
698c3ff71bSJunchao Zhang   ierr = MatSeqAIJSetPreallocation(mpiaij->A,d_nz,d_nnz);CHKERRQ(ierr);
708c3ff71bSJunchao Zhang   ierr = MatSeqAIJSetPreallocation(mpiaij->B,o_nz,o_nnz);CHKERRQ(ierr);
718c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
728c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
738c3ff71bSJunchao Zhang }
748c3ff71bSJunchao Zhang 
758c3ff71bSJunchao Zhang PetscErrorCode MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
768c3ff71bSJunchao Zhang {
778c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
788c3ff71bSJunchao Zhang   PetscErrorCode ierr;
798c3ff71bSJunchao Zhang   PetscInt       nt;
808c3ff71bSJunchao Zhang 
818c3ff71bSJunchao Zhang   PetscFunctionBegin;
828c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
838c3ff71bSJunchao Zhang   if (nt != mat->cmap->n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%D) and xx (%D)",mat->cmap->n,nt);
848c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
858c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->mult)(mpiaij->A,xx,yy);CHKERRQ(ierr);
868c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
878c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,yy,yy);CHKERRQ(ierr);
888c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
898c3ff71bSJunchao Zhang }
908c3ff71bSJunchao Zhang 
918c3ff71bSJunchao Zhang PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)
928c3ff71bSJunchao Zhang {
938c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
948c3ff71bSJunchao Zhang   PetscErrorCode ierr;
958c3ff71bSJunchao Zhang   PetscInt       nt;
968c3ff71bSJunchao Zhang 
978c3ff71bSJunchao Zhang   PetscFunctionBegin;
988c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
998c3ff71bSJunchao Zhang   if (nt != mat->cmap->n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%D) and xx (%D)",mat->cmap->n,nt);
1008c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
1018c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->multadd)(mpiaij->A,xx,yy,zz);CHKERRQ(ierr);
1028c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
1038c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,zz,zz);CHKERRQ(ierr);
1048c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1058c3ff71bSJunchao Zhang }
1068c3ff71bSJunchao Zhang 
1078c3ff71bSJunchao Zhang PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
1088c3ff71bSJunchao Zhang {
1098c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
1108c3ff71bSJunchao Zhang   PetscErrorCode ierr;
1118c3ff71bSJunchao Zhang   PetscInt       nt;
1128c3ff71bSJunchao Zhang 
1138c3ff71bSJunchao Zhang   PetscFunctionBegin;
1148c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
1158c3ff71bSJunchao Zhang   if (nt != mat->rmap->n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%D) and xx (%D)",mat->rmap->n,nt);
1168c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multtranspose)(mpiaij->B,xx,mpiaij->lvec);CHKERRQ(ierr);
1178c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->multtranspose)(mpiaij->A,xx,yy);CHKERRQ(ierr);
1188c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
1198c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
1208c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1218c3ff71bSJunchao Zhang }
1228c3ff71bSJunchao Zhang 
123*076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
124*076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
125*076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
126*076ba34aSJunchao Zhang */
127*076ba34aSJunchao Zhang PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS *glob,Mat *C)
128*076ba34aSJunchao Zhang {
129*076ba34aSJunchao Zhang   Mat            Ad,Ao;
130*076ba34aSJunchao Zhang   const PetscInt *cmap;
131*076ba34aSJunchao Zhang   PetscErrorCode ierr;
132*076ba34aSJunchao Zhang 
133*076ba34aSJunchao Zhang   PetscFunctionBegin;
134*076ba34aSJunchao Zhang   ierr = MatMPIAIJGetSeqAIJ(mat,&Ad,&Ao,&cmap);CHKERRQ(ierr);
135*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosMergeMats(Ad,Ao,reuse,C);CHKERRQ(ierr);
136*076ba34aSJunchao Zhang   if (glob) {
137*076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
138*076ba34aSJunchao Zhang     ierr = MatGetLocalSize(Ad,NULL,&dn);CHKERRQ(ierr);
139*076ba34aSJunchao Zhang     ierr = MatGetLocalSize(Ao,NULL,&on);CHKERRQ(ierr);
140*076ba34aSJunchao Zhang     ierr = MatGetOwnershipRangeColumn(mat,&cst,NULL);CHKERRQ(ierr);
141*076ba34aSJunchao Zhang     ierr = PetscMalloc1(dn+on,&gidx);CHKERRQ(ierr);
142*076ba34aSJunchao Zhang     for (i=0; i<dn; i++) gidx[i]    = cst + i;
143*076ba34aSJunchao Zhang     for (i=0; i<on; i++) gidx[i+dn] = cmap[i];
144*076ba34aSJunchao Zhang     ierr = ISCreateGeneral(PetscObjectComm((PetscObject)Ad),dn+on,gidx,PETSC_OWN_POINTER,glob);CHKERRQ(ierr);
145*076ba34aSJunchao Zhang   }
146*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
147*076ba34aSJunchao Zhang }
148*076ba34aSJunchao Zhang 
149*076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
150*076ba34aSJunchao Zhang struct MatMatStruct {
151*076ba34aSJunchao Zhang   MatRowMapKokkosView   Cdstart; /* Used to split sequential matrix into petsc's A, B format */
152*076ba34aSJunchao Zhang   PetscSF               sf; /* SF to send/recv matrix entries */
153*076ba34aSJunchao Zhang   MatScalarKokkosView   abuf; /* buf of mat values in send/recv */
154*076ba34aSJunchao Zhang   Mat                   C1,C2,B_local;
155*076ba34aSJunchao Zhang   KokkosCsrMatrix       C1_global,C2_global,C_global;
156*076ba34aSJunchao Zhang   KernelHandle          kh;
157*076ba34aSJunchao Zhang   MatMatStruct() {
158*076ba34aSJunchao Zhang     C1 = C2 = B_local = NULL;
159*076ba34aSJunchao Zhang     sf = NULL;
160*076ba34aSJunchao Zhang   }
161*076ba34aSJunchao Zhang 
162*076ba34aSJunchao Zhang   ~MatMatStruct() {
163*076ba34aSJunchao Zhang     MatDestroy(&C1);
164*076ba34aSJunchao Zhang     MatDestroy(&C2);
165*076ba34aSJunchao Zhang     MatDestroy(&B_local);
166*076ba34aSJunchao Zhang     PetscSFDestroy(&sf);
167*076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
168*076ba34aSJunchao Zhang   }
169*076ba34aSJunchao Zhang };
170*076ba34aSJunchao Zhang 
171*076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
172*076ba34aSJunchao Zhang   MatColIdxKokkosView   rows;
173*076ba34aSJunchao Zhang   MatRowMapKokkosView   rowoffset;
174*076ba34aSJunchao Zhang   Mat                   B_other,C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
175*076ba34aSJunchao Zhang 
176*076ba34aSJunchao Zhang   MatMatStruct_AB() : B_other(NULL),C_petsc(NULL){}
177*076ba34aSJunchao Zhang   ~MatMatStruct_AB() {
178*076ba34aSJunchao Zhang     MatDestroy(&B_other);
179*076ba34aSJunchao Zhang     MatDestroy(&C_petsc);
180*076ba34aSJunchao Zhang   }
181*076ba34aSJunchao Zhang };
182*076ba34aSJunchao Zhang 
183*076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
184*076ba34aSJunchao Zhang   MatRowMapKokkosView   srcrowoffset,dstrowoffset;
185*076ba34aSJunchao Zhang };
186*076ba34aSJunchao Zhang 
187*076ba34aSJunchao Zhang struct MatProductData_MPIAIJKokkos
188*076ba34aSJunchao Zhang {
189*076ba34aSJunchao Zhang   MatMatStruct_AB   *mmAB;
190*076ba34aSJunchao Zhang   MatMatStruct_AtB  *mmAtB;
191*076ba34aSJunchao Zhang   PetscBool         reusesym;
192*076ba34aSJunchao Zhang 
193*076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos(): mmAB(NULL),mmAtB(NULL),reusesym(PETSC_FALSE){}
194*076ba34aSJunchao Zhang   ~MatProductData_MPIAIJKokkos() {
195*076ba34aSJunchao Zhang     delete mmAB;
196*076ba34aSJunchao Zhang     delete mmAtB;
197*076ba34aSJunchao Zhang   }
198*076ba34aSJunchao Zhang };
199*076ba34aSJunchao Zhang 
200*076ba34aSJunchao Zhang static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
201*076ba34aSJunchao Zhang {
202*076ba34aSJunchao Zhang   PetscFunctionBegin;
203*076ba34aSJunchao Zhang   CHKERRCXX(delete static_cast<MatProductData_MPIAIJKokkos*>(data));
204*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
205*076ba34aSJunchao Zhang }
206*076ba34aSJunchao Zhang 
207*076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
208*076ba34aSJunchao Zhang 
209*076ba34aSJunchao Zhang    Input Parameters:
210*076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
211*076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
212*076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
213*076ba34aSJunchao Zhang 
214*076ba34aSJunchao Zhang    Output Parameters:
215*076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
216*076ba34aSJunchao Zhang  */
217*076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A,PetscInt N,const ConstMatColIdxKokkosView& l2g,KokkosCsrMatrix& csrmat)
218*076ba34aSJunchao Zhang {
219*076ba34aSJunchao Zhang   KokkosCsrMatrix&         orig = static_cast<Mat_SeqAIJKokkos*>(A->spptr)->csrmat;
220*076ba34aSJunchao Zhang   MatColIdxKokkosView      jg("jg",orig.nnz()); /* New j array for csrmat */
221*076ba34aSJunchao Zhang 
222*076ba34aSJunchao Zhang   PetscFunctionBegin;
223*076ba34aSJunchao Zhang   CHKERRCXX(Kokkos::parallel_for(orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) {jg(i) = l2g(orig.graph.entries(i));}));
224*076ba34aSJunchao Zhang   CHKERRCXX(csrmat = KokkosCsrMatrix("csrmat",orig.numRows(),N,orig.nnz(),orig.values,orig.graph.row_map,jg));
225*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
226*076ba34aSJunchao Zhang }
227*076ba34aSJunchao Zhang 
228*076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
229*076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
230*076ba34aSJunchao Zhang 
231*076ba34aSJunchao Zhang   Input Parameters:
232*076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
233*076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
234*076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
235*076ba34aSJunchao Zhang 
236*076ba34aSJunchao Zhang   Output Parameters:
237*076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
238*076ba34aSJunchao Zhang */
239*076ba34aSJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat,Mat A,Mat B)
240*076ba34aSJunchao Zhang {
241*076ba34aSJunchao Zhang   PetscErrorCode      ierr;
242*076ba34aSJunchao Zhang   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
243*076ba34aSJunchao Zhang   PetscInt            m,n,M,N,Am,An,Bm,Bn;
244*076ba34aSJunchao Zhang   Mat_SeqAIJKokkos    *bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
245*076ba34aSJunchao Zhang 
246*076ba34aSJunchao Zhang   PetscFunctionBegin;
247*076ba34aSJunchao Zhang   ierr = MatGetSize(mat,&M,&N);CHKERRQ(ierr);
248*076ba34aSJunchao Zhang   ierr = MatGetLocalSize(mat,&m,&n);CHKERRQ(ierr);
249*076ba34aSJunchao Zhang   ierr = MatGetLocalSize(A,&Am,&An);CHKERRQ(ierr);
250*076ba34aSJunchao Zhang   ierr = MatGetLocalSize(B,&Bm,&Bn);CHKERRQ(ierr);
251*076ba34aSJunchao Zhang 
252*076ba34aSJunchao Zhang   if (m != Am || m != Bm) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of rows do not match");
253*076ba34aSJunchao Zhang   if (n != An) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of columns do not match");
254*076ba34aSJunchao Zhang   if (N != Bn) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"global number of columns do not match");
255*076ba34aSJunchao Zhang   if (mpiaij->A || mpiaij->B) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"A, B of the MPIAIJ matrix are not empty");
256*076ba34aSJunchao Zhang   mpiaij->A = A;
257*076ba34aSJunchao Zhang   mpiaij->B = B;
258*076ba34aSJunchao Zhang 
259*076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
260*076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
261*076ba34aSJunchao Zhang 
262*076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE);CHKERRQ(ierr);
263*076ba34aSJunchao Zhang   ierr = MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
264*076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
265*076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
266*076ba34aSJunchao Zhang   */
267*076ba34aSJunchao Zhang   ierr = MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
268*076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_FALSE);CHKERRQ(ierr);
269*076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NEW_NONZERO_LOCATION_ERR,PETSC_TRUE);CHKERRQ(ierr);
270*076ba34aSJunchao Zhang 
271*076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
272*076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
273*076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
274*076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
275*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
276*076ba34aSJunchao Zhang }
277*076ba34aSJunchao Zhang 
278*076ba34aSJunchao Zhang /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
279*076ba34aSJunchao Zhang 
280*076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
281*076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
282*076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
283*076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
284*076ba34aSJunchao Zhang 
285*076ba34aSJunchao Zhang    Collective on comm of ownerSF
286*076ba34aSJunchao Zhang 
287*076ba34aSJunchao Zhang    Input Parameters:
288*076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
289*076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
290*076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
291*076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
292*076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
293*076ba34aSJunchao Zhang 
294*076ba34aSJunchao Zhang    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
295*076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
296*076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
297*076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
298*076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
299*076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
300*076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
301*076ba34aSJunchao Zhang */
302*076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosBcast(Mat B,MatReuse reuse,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
303*076ba34aSJunchao Zhang                                            PetscSF& bcastSF,MatScalarKokkosView& abuf,MatColIdxKokkosView& rows,
304*076ba34aSJunchao Zhang                                            MatRowMapKokkosView& rowoffset,Mat& C)
305*076ba34aSJunchao Zhang {
306*076ba34aSJunchao Zhang   PetscErrorCode               ierr;
307*076ba34aSJunchao Zhang   Mat_SeqAIJKokkos             *bkok,*ckok;
308*076ba34aSJunchao Zhang 
309*076ba34aSJunchao Zhang   PetscFunctionBegin;
310*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosSyncDevice(B);CHKERRQ(ierr); /* Make sure B->spptr is accessible */
311*076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
312*076ba34aSJunchao Zhang 
313*076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
314*076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
315*076ba34aSJunchao Zhang 
316*076ba34aSJunchao Zhang     const auto& Ba = bkok->a_dual.view_device();
317*076ba34aSJunchao Zhang     const auto& Bi = bkok->i_dual.view_device();
318*076ba34aSJunchao Zhang     const auto& Ca = ckok->a_dual.view_device();
319*076ba34aSJunchao Zhang 
320*076ba34aSJunchao Zhang     /* Copy Ba to abuf */
321*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
322*076ba34aSJunchao Zhang       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
323*076ba34aSJunchao Zhang       PetscInt r    = rows(i);
324*076ba34aSJunchao Zhang       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
325*076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
326*076ba34aSJunchao Zhang         abuf(base+k) = Ba(Bi(r)+k);
327*076ba34aSJunchao Zhang       });
328*076ba34aSJunchao Zhang     });
329*076ba34aSJunchao Zhang 
330*076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
331*076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr); /* TODO: get memtype for abuf */
332*076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr);
333*076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
334*076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
335*076ba34aSJunchao Zhang     MPI_Comm       comm;
336*076ba34aSJunchao Zhang     PetscMPIInt    tag;
337*076ba34aSJunchao Zhang     PetscInt       k,Cm,Cn,Cnnz,*Ci_h,nroots,nleaves;
338*076ba34aSJunchao Zhang 
339*076ba34aSJunchao Zhang     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRMPI(ierr);
340*076ba34aSJunchao Zhang     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
341*076ba34aSJunchao Zhang     Cm   = nleaves; /* row size of C */
342*076ba34aSJunchao Zhang     Cn   = N;  /* col size of C, which initially uses global ids, so we can safely set its col size as N */
343*076ba34aSJunchao Zhang 
344*076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
345*076ba34aSJunchao Zhang     PetscInt       *Browlens;
346*076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
347*076ba34aSJunchao Zhang     ierr = PetscMalloc1(nroots,&Browlens);CHKERRQ(ierr);
348*076ba34aSJunchao Zhang     for (k=0; k<nroots; k++) Browlens[k] = tmp[k+1]-tmp[k];
349*076ba34aSJunchao Zhang 
350*076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
351*076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i",Cm+1); /* C's rowmap */
352*076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
353*076ba34aSJunchao Zhang     Ci_h[0] = 0;
354*076ba34aSJunchao Zhang     ierr    = PetscSFBcastWithMemTypeBegin(ownerSF,MPIU_INT,PETSC_MEMTYPE_HOST,Browlens,PETSC_MEMTYPE_HOST,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
355*076ba34aSJunchao Zhang     ierr    = PetscSFBcastEnd(ownerSF,MPIU_INT,Browlens,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
356*076ba34aSJunchao Zhang     for (k=1; k<Cm+1; k++) Ci_h[k] += Ci_h[k-1]; /* Convert lens to CSR */
357*076ba34aSJunchao Zhang     Cnnz    = Ci_h[Cm];
358*076ba34aSJunchao Zhang     Ci.modify_host();
359*076ba34aSJunchao Zhang     Ci.sync_device();
360*076ba34aSJunchao Zhang 
361*076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
362*076ba34aSJunchao Zhang     MatColIdxKokkosDualView  Cj("j",Cnnz);
363*076ba34aSJunchao Zhang     MatScalarKokkosDualView  Ca("a",Cnnz);
364*076ba34aSJunchao Zhang 
365*076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
366*076ba34aSJunchao Zhang     const PetscMPIInt *iranks,*ranks;
367*076ba34aSJunchao Zhang     const PetscInt    *ioffset,*irootloc,*roffset;
368*076ba34aSJunchao Zhang     PetscInt          i,j,niranks,nranks,*sdisp,*rdisp,*rowptr;
369*076ba34aSJunchao Zhang     MPI_Request       *reqs;
370*076ba34aSJunchao Zhang 
371*076ba34aSJunchao Zhang     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&irootloc);CHKERRQ(ierr); /* irootloc[] contains indices of rows I need to send to each receiver */
372*076ba34aSJunchao Zhang     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* recv info */
373*076ba34aSJunchao Zhang 
374*076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
375*076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
376*076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
377*076ba34aSJunchao Zhang 
378*076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
379*076ba34aSJunchao Zhang     */
380*076ba34aSJunchao Zhang     ierr = PetscMalloc3(niranks+1,&sdisp,nranks,&rdisp,niranks+nranks,&reqs);CHKERRQ(ierr);
381*076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h",ioffset[niranks]+1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
382*076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
383*076ba34aSJunchao Zhang 
384*076ba34aSJunchao Zhang     sdisp[0] = 0;
385*076ba34aSJunchao Zhang     rowptr[0]  = 0;
386*076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) { /* for each receiver */
387*076ba34aSJunchao Zhang       PetscInt len, nz = 0;
388*076ba34aSJunchao Zhang       for (j=ioffset[i]; j<ioffset[i+1]; j++) { /* for each row to this receiver */
389*076ba34aSJunchao Zhang         len         = Browlens[irootloc[j]];
390*076ba34aSJunchao Zhang         rowptr[j+1] = rowptr[j] + len;
391*076ba34aSJunchao Zhang         nz         += len;
392*076ba34aSJunchao Zhang       }
393*076ba34aSJunchao Zhang       sdisp[i+1] = sdisp[i] + nz;
394*076ba34aSJunchao Zhang     }
395*076ba34aSJunchao Zhang     ierr = PetscCommGetNewTag(comm,&tag);CHKERRMPI(ierr);
396*076ba34aSJunchao Zhang     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
397*076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
398*076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
399*076ba34aSJunchao Zhang 
400*076ba34aSJunchao Zhang     PetscInt    nleaves2 = Cnnz; /* leaves are the nonzeros I will receive */
401*076ba34aSJunchao Zhang     PetscInt    nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
402*076ba34aSJunchao Zhang     PetscSFNode *iremote;
403*076ba34aSJunchao Zhang     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr);
404*076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) { /* for each sender */
405*076ba34aSJunchao Zhang       k = 0;
406*076ba34aSJunchao Zhang       for (j=Ci_h[roffset[i]]; j<Ci_h[roffset[i+1]]; j++) {
407*076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
408*076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
409*076ba34aSJunchao Zhang         k++;
410*076ba34aSJunchao Zhang       }
411*076ba34aSJunchao Zhang     }
412*076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
413*076ba34aSJunchao Zhang     ierr = PetscSFCreate(comm,&bcastSF);CHKERRQ(ierr);
414*076ba34aSJunchao Zhang     ierr = PetscSFSetGraph(bcastSF,nroots2,nleaves2,NULL/*ilocal*/,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
415*076ba34aSJunchao Zhang 
416*076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
417*076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
418*076ba34aSJunchao Zhang     */
419*076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc,ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
420*076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows",ioffset[niranks]);
421*076ba34aSJunchao Zhang     Kokkos::deep_copy(rows,rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
422*076ba34aSJunchao Zhang 
423*076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
424*076ba34aSJunchao Zhang 
425*076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf",sdisp[niranks]); /* send buf for (global) col ids */
426*076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf",sdisp[niranks]); /* send buf for mat values */
427*076ba34aSJunchao Zhang 
428*076ba34aSJunchao Zhang     const auto& Ba = bkok->a_dual.view_device();
429*076ba34aSJunchao Zhang     const auto& Bi = bkok->i_dual.view_device();
430*076ba34aSJunchao Zhang     const auto& Bj = bkok->j_dual.view_device();
431*076ba34aSJunchao Zhang 
432*076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
433*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
434*076ba34aSJunchao Zhang       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
435*076ba34aSJunchao Zhang       PetscInt r    = rows(i);
436*076ba34aSJunchao Zhang       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
437*076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
438*076ba34aSJunchao Zhang         abuf(base+k) = Ba(Bi(r)+k);
439*076ba34aSJunchao Zhang         jbuf(base+k) = l2g(Bj(Bi(r)+k));
440*076ba34aSJunchao Zhang       });
441*076ba34aSJunchao Zhang     });
442*076ba34aSJunchao Zhang 
443*076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
444*076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
445*076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
446*076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
447*076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
448*076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
449*076ba34aSJunchao Zhang     Cj.sync_host();
450*076ba34aSJunchao Zhang     Ca.modify_device();
451*076ba34aSJunchao Zhang 
452*076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
453*076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm,Cn,Cnnz,Ci,Cj,Ca);
454*076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,ckok,&C);CHKERRQ(ierr);
455*076ba34aSJunchao Zhang     ierr = PetscFree3(sdisp,rdisp,reqs);CHKERRQ(ierr);
456*076ba34aSJunchao Zhang     ierr = PetscFree(Browlens);CHKERRQ(ierr);
457*076ba34aSJunchao Zhang   } else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d\n",reuse);
458*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
459*076ba34aSJunchao Zhang }
460*076ba34aSJunchao Zhang 
461*076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
462*076ba34aSJunchao Zhang 
463*076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
464*076ba34aSJunchao Zhang 
465*076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
466*076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
467*076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
468*076ba34aSJunchao Zhang 
469*076ba34aSJunchao Zhang   Input Parameters:
470*076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
471*076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
472*076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
473*076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
474*076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
475*076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
476*076ba34aSJunchao Zhang 
477*076ba34aSJunchao Zhang   Output Parameters:
478*076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
479*076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
480*076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
481*076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
482*076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
483*076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
484*076ba34aSJunchao Zhang 
485*076ba34aSJunchao Zhang    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
486*076ba34aSJunchao Zhang  */
487*076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosReduce(Mat A,MatReuse reuse,PetscBool local,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
488*076ba34aSJunchao Zhang                                             PetscSF& reduceSF,MatScalarKokkosView& abuf,
489*076ba34aSJunchao Zhang                                             MatRowMapKokkosView& srcrowoffset,MatRowMapKokkosView& dstrowoffset,
490*076ba34aSJunchao Zhang                                             KokkosCsrMatrix& C)
491*076ba34aSJunchao Zhang {
492*076ba34aSJunchao Zhang   PetscErrorCode         ierr;
493*076ba34aSJunchao Zhang   PetscInt               i,r,Am,An,Annz,Cnnz,nrows;
494*076ba34aSJunchao Zhang   const PetscInt         *Ai;
495*076ba34aSJunchao Zhang   Mat_SeqAIJKokkos       *akok;
496*076ba34aSJunchao Zhang 
497*076ba34aSJunchao Zhang   PetscFunctionBegin;
498*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosSyncDevice(A);CHKERRQ(ierr); /* So that A's latest data is on device */
499*076ba34aSJunchao Zhang   ierr = MatGetSize(A,&Am,&An);
500*076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ*>(A->data)->i;
501*076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
502*076ba34aSJunchao Zhang   Annz = Ai[Am];
503*076ba34aSJunchao Zhang 
504*076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
505*076ba34aSJunchao Zhang     /* Send Aa to abuf */
506*076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
507*076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
508*076ba34aSJunchao Zhang 
509*076ba34aSJunchao Zhang     /* Copy abuf to Ca */
510*076ba34aSJunchao Zhang     const MatScalarKokkosView& Ca = C.values;
511*076ba34aSJunchao Zhang     nrows = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
512*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
513*076ba34aSJunchao Zhang       PetscInt i   = t.league_rank();
514*076ba34aSJunchao Zhang       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
515*076ba34aSJunchao Zhang       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
516*076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {Ca(dst+k) = abuf(src+k);});
517*076ba34aSJunchao Zhang     });
518*076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
519*076ba34aSJunchao Zhang     MPI_Comm               comm;
520*076ba34aSJunchao Zhang     MPI_Request            *reqs;
521*076ba34aSJunchao Zhang     PetscMPIInt            tag;
522*076ba34aSJunchao Zhang     PetscInt               Cm;
523*076ba34aSJunchao Zhang 
524*076ba34aSJunchao Zhang     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRQ(ierr);
525*076ba34aSJunchao Zhang     ierr = PetscCommGetNewTag(comm,&tag);CHKERRQ(ierr);
526*076ba34aSJunchao Zhang 
527*076ba34aSJunchao Zhang     PetscInt niranks,nranks,nroots,nleaves;
528*076ba34aSJunchao Zhang     const PetscMPIInt *iranks,*ranks;
529*076ba34aSJunchao Zhang     const PetscInt *ioffset,*rows,*roffset;  /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
530*076ba34aSJunchao Zhang     ierr = PetscSFSetUp(ownerSF);CHKERRQ(ierr);
531*076ba34aSJunchao Zhang     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&rows);CHKERRQ(ierr); /* recv info: iranks[] will send rows to me */
532*076ba34aSJunchao Zhang     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* send info */
533*076ba34aSJunchao Zhang     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
534*076ba34aSJunchao Zhang     if (nleaves != Am) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_PLIB,"ownerSF's nleaves(%D) != row size of A(%D)\n",nleaves,Am);
535*076ba34aSJunchao Zhang     Cm    = nroots;
536*076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
537*076ba34aSJunchao Zhang 
538*076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
539*076ba34aSJunchao Zhang     PetscInt                *srowlens; /* send buf of row lens */
540*076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h",nrows+1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
541*076ba34aSJunchao Zhang     PetscInt                *rrowlens = rrowlens_h.data();
542*076ba34aSJunchao Zhang 
543*076ba34aSJunchao Zhang     ierr = PetscMalloc2(Am,&srowlens,niranks+nranks,&reqs);CHKERRQ(ierr);
544*076ba34aSJunchao Zhang     for (i=0; i<Am; i++) srowlens[i] = Ai[i+1] - Ai[i];
545*076ba34aSJunchao Zhang     rrowlens[0] = 0;
546*076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
547*076ba34aSJunchao Zhang     for (i=0; i<niranks; i++){ierr = MPI_Irecv(&rrowlens[ioffset[i]],ioffset[i+1]-ioffset[i],MPIU_INT,iranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
548*076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) {ierr = MPI_Isend(&srowlens[roffset[i]],roffset[i+1]-roffset[i],MPIU_INT,ranks[i],tag,comm,&reqs[niranks+i]);CHKERRMPI(ierr);}
549*076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
550*076ba34aSJunchao Zhang 
551*076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
552*076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i",Cm+1);
553*076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h,0); /* Zero Ci */
554*076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
555*076ba34aSJunchao Zhang 
556*076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) {
557*076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
558*076ba34aSJunchao Zhang      #if defined(PETSC_USE_DEBUG)
559*076ba34aSJunchao Zhang       if (r<0 || r>=Cm) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_PLIB,"local row id (%D) is out of range [0,%D)\n",r,Cm);
560*076ba34aSJunchao Zhang      #endif
561*076ba34aSJunchao Zhang       Ci_ptr[r+1] += rrowlens[i]; /* add to length of row r in C */
562*076ba34aSJunchao Zhang     }
563*076ba34aSJunchao Zhang     for (i=0; i<Cm; i++) Ci_ptr[i+1] += Ci_ptr[i]; /* to CSR format */
564*076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
565*076ba34aSJunchao Zhang 
566*076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
567*076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h",nrows);
568*076ba34aSJunchao Zhang     PetscInt                *dstrowoffset_hptr = dstrowoffset_h.data();
569*076ba34aSJunchao Zhang     PetscInt                *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
570*076ba34aSJunchao Zhang 
571*076ba34aSJunchao Zhang     ierr = PetscCalloc1(Cm,&currowlens);CHKERRQ(ierr); /* Init with zero, to be added to */
572*076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) { /* for each row I receive */
573*076ba34aSJunchao Zhang       r                    = rows[i]; /* row id in C */
574*076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
575*076ba34aSJunchao Zhang       currowlens[r]       += rrowlens[i]; /* accumulate to length of row r in C */
576*076ba34aSJunchao Zhang     }
577*076ba34aSJunchao Zhang     ierr = PetscFree(currowlens);CHKERRQ(ierr);
578*076ba34aSJunchao Zhang 
579*076ba34aSJunchao Zhang     rrowlens--;
580*076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) rrowlens[i+1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
581*076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),dstrowoffset_h);
582*076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
583*076ba34aSJunchao Zhang 
584*076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
585*076ba34aSJunchao Zhang     PetscInt *sdisp,*rdisp; /* buffer to send offsets of roots, and buffer to recv them */
586*076ba34aSJunchao Zhang     ierr = PetscMalloc2(niranks,&sdisp,nranks,&rdisp);CHKERRQ(ierr);
587*076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
588*076ba34aSJunchao Zhang     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
589*076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
590*076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
591*076ba34aSJunchao Zhang 
592*076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
593*076ba34aSJunchao Zhang     PetscInt    nroots2 = Cnnz,nleaves2 = Annz;
594*076ba34aSJunchao Zhang     PetscSFNode *iremote;
595*076ba34aSJunchao Zhang     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr); /* no free, since memory will be given to reduceSF */
596*076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) {
597*076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i]; /* root offset at this root rank */
598*076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]]; /* leaf base */
599*076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i+1]] - leafbase; /* I will send nz nonzeros to this root rank */
600*076ba34aSJunchao Zhang       for (PetscInt k=0; k<nz; k++) {
601*076ba34aSJunchao Zhang         iremote[leafbase+k].rank  = ranks[i];
602*076ba34aSJunchao Zhang         iremote[leafbase+k].index = rootbase + k;
603*076ba34aSJunchao Zhang       }
604*076ba34aSJunchao Zhang     }
605*076ba34aSJunchao Zhang     ierr = PetscSFCreate(comm,&reduceSF);CHKERRQ(ierr);
606*076ba34aSJunchao Zhang     ierr = PetscSFSetGraph(reduceSF,nroots2,nleaves2,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
607*076ba34aSJunchao Zhang     ierr = PetscFree2(sdisp,rdisp);CHKERRQ(ierr);
608*076ba34aSJunchao Zhang 
609*076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
610*076ba34aSJunchao Zhang 
611*076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
612*076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
613*076ba34aSJunchao Zhang     if (local) {
614*076ba34aSJunchao Zhang       Ajg = MatColIdxKokkosView("j",Annz);
615*076ba34aSJunchao Zhang       const MatColIdxKokkosView& Aj = akok->j_dual.view_device();
616*076ba34aSJunchao Zhang       Kokkos::parallel_for(Annz,KOKKOS_LAMBDA(const PetscInt i) {Ajg(i) = l2g(Aj(i));});
617*076ba34aSJunchao Zhang     } else {
618*076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
619*076ba34aSJunchao Zhang     }
620*076ba34aSJunchao Zhang 
621*076ba34aSJunchao Zhang     MatColIdxKokkosView   jbuf("jbuf",Cnnz);
622*076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf",Cnnz);
623*076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
624*076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
625*076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
626*076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
627*076ba34aSJunchao Zhang 
628*076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
629*076ba34aSJunchao Zhang     MatRowMapKokkosView    Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),Ci_h); /* Ci is an alias of Ci_h if no device */
630*076ba34aSJunchao Zhang     MatColIdxKokkosView    Cj("j",Cnnz);
631*076ba34aSJunchao Zhang     MatScalarKokkosView    Ca("a",Cnnz);
632*076ba34aSJunchao Zhang 
633*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
634*076ba34aSJunchao Zhang       PetscInt i   = t.league_rank();
635*076ba34aSJunchao Zhang       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
636*076ba34aSJunchao Zhang       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
637*076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {
638*076ba34aSJunchao Zhang         Ca(dst+k) = abuf(src+k);
639*076ba34aSJunchao Zhang         Cj(dst+k) = jbuf(src+k);
640*076ba34aSJunchao Zhang       });
641*076ba34aSJunchao Zhang     });
642*076ba34aSJunchao Zhang 
643*076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
644*076ba34aSJunchao Zhang     C    = KokkosCsrMatrix("csrmat",Cm,N,Cnnz,Ca,Ci,Cj);
645*076ba34aSJunchao Zhang     ierr = PetscFree2(srowlens,reqs);CHKERRQ(ierr);
646*076ba34aSJunchao Zhang   } else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d\n",reuse);
647*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
648*076ba34aSJunchao Zhang }
649*076ba34aSJunchao Zhang 
650*076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
651*076ba34aSJunchao Zhang 
652*076ba34aSJunchao Zhang   Input Parameters:
653*076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
654*076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
655*076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
656*076ba34aSJunchao Zhang -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
657*076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
658*076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
659*076ba34aSJunchao Zhang 
660*076ba34aSJunchao Zhang   Output Parameters:
661*076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
662*076ba34aSJunchao Zhang -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
663*076ba34aSJunchao Zhang 
664*076ba34aSJunchao Zhang   Notes:
665*076ba34aSJunchao Zhang    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
666*076ba34aSJunchao Zhang  */
667*076ba34aSJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C,MatReuse reuse,const KokkosCsrMatrix& csrmat,MatRowMapKokkosView& Cdstart)
668*076ba34aSJunchao Zhang {
669*076ba34aSJunchao Zhang   PetscErrorCode                  ierr;
670*076ba34aSJunchao Zhang   const MatScalarKokkosView&      Ca = csrmat.values;
671*076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView& Ci = csrmat.graph.row_map;
672*076ba34aSJunchao Zhang   PetscInt                        m,n,N;
673*076ba34aSJunchao Zhang 
674*076ba34aSJunchao Zhang   PetscFunctionBegin;
675*076ba34aSJunchao Zhang   ierr = MatGetLocalSize(C,&m,&n);CHKERRQ(ierr);
676*076ba34aSJunchao Zhang   ierr = MatGetSize(C,NULL,&N);CHKERRQ(ierr);
677*076ba34aSJunchao Zhang 
678*076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
679*076ba34aSJunchao Zhang     Mat_MPIAIJ                  *mpiaij = static_cast<Mat_MPIAIJ*>(C->data);
680*076ba34aSJunchao Zhang     Mat_SeqAIJKokkos            *akok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr);
681*076ba34aSJunchao Zhang     Mat_SeqAIJKokkos            *bkok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->B->spptr);
682*076ba34aSJunchao Zhang     const MatScalarKokkosView&  Cda = akok->a_dual.view_device(),Coa = bkok->a_dual.view_device();
683*076ba34aSJunchao Zhang     const MatRowMapKokkosView&  Cdi = akok->i_dual.view_device(),Coi = bkok->i_dual.view_device();
684*076ba34aSJunchao Zhang 
685*076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
686*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
687*076ba34aSJunchao Zhang       PetscInt i       = t.league_rank(); /* row i */
688*076ba34aSJunchao Zhang       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
689*076ba34aSJunchao Zhang       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
690*076ba34aSJunchao Zhang       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
691*076ba34aSJunchao Zhang       PetscInt cdend   = cdstart + cdlen;
692*076ba34aSJunchao Zhang       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
693*076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
694*076ba34aSJunchao Zhang         if (k < cdstart) {  /* k in [0, cdstart) */
695*076ba34aSJunchao Zhang           Coa(Coi(i)+k) = Ca(Ci(i)+k);
696*076ba34aSJunchao Zhang         } else if (k < cdend) { /* k in [cdstart, cdend) */
697*076ba34aSJunchao Zhang           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
698*076ba34aSJunchao Zhang         } else { /* k in [cdend, clen) */
699*076ba34aSJunchao Zhang           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
700*076ba34aSJunchao Zhang         }
701*076ba34aSJunchao Zhang       });
702*076ba34aSJunchao Zhang     });
703*076ba34aSJunchao Zhang 
704*076ba34aSJunchao Zhang     akok->a_dual.modify_device();
705*076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
706*076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
707*076ba34aSJunchao Zhang     Mat                         Cd,Co;
708*076ba34aSJunchao Zhang     const MatColIdxKokkosView&  Cj = csrmat.graph.entries;
709*076ba34aSJunchao Zhang     MatRowMapKokkosDualView     Cdi_dual("i",m+1),Coi_dual("i",m+1);
710*076ba34aSJunchao Zhang     MatRowMapKokkosView         Cdi = Cdi_dual.view_device(),Coi = Coi_dual.view_device();
711*076ba34aSJunchao Zhang     PetscInt                    cstart,cend;
712*076ba34aSJunchao Zhang 
713*076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
714*076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
715*076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
716*076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
717*076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
718*076ba34aSJunchao Zhang      */
719*076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart",m);
720*076ba34aSJunchao Zhang     ierr    = PetscLayoutGetRange(C->cmap,&cstart,&cend);CHKERRQ(ierr); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
721*076ba34aSJunchao Zhang 
722*076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
723*076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
724*076ba34aSJunchao Zhang      */
725*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, 1),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
726*076ba34aSJunchao Zhang       Kokkos::single(Kokkos::PerTeam(t), [=] () { /* Only one thread works in a team */
727*076ba34aSJunchao Zhang         PetscInt i = t.league_rank(); /* row i */
728*076ba34aSJunchao Zhang         PetscInt j,first,count,step;
729*076ba34aSJunchao Zhang 
730*076ba34aSJunchao Zhang         if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
731*076ba34aSJunchao Zhang           Cdi(0) = 0;
732*076ba34aSJunchao Zhang           Coi(0) = 0;
733*076ba34aSJunchao Zhang         }
734*076ba34aSJunchao Zhang 
735*076ba34aSJunchao Zhang         /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
736*076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
737*076ba34aSJunchao Zhang         */
738*076ba34aSJunchao Zhang         count = Ci(i+1)-Ci(i);
739*076ba34aSJunchao Zhang         first = Ci(i);
740*076ba34aSJunchao Zhang         while (count > 0) {
741*076ba34aSJunchao Zhang           j    = first;
742*076ba34aSJunchao Zhang           step = count / 2;
743*076ba34aSJunchao Zhang           j   += step;
744*076ba34aSJunchao Zhang           if (Cj(j) < cstart) {
745*076ba34aSJunchao Zhang             first  = ++j;
746*076ba34aSJunchao Zhang             count -= step + 1;
747*076ba34aSJunchao Zhang           } else count = step;
748*076ba34aSJunchao Zhang         }
749*076ba34aSJunchao Zhang         Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
750*076ba34aSJunchao Zhang 
751*076ba34aSJunchao Zhang         /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
752*076ba34aSJunchao Zhang         count = Ci(i+1) - first;
753*076ba34aSJunchao Zhang         while (count > 0) {
754*076ba34aSJunchao Zhang           j    = first;
755*076ba34aSJunchao Zhang           step = count / 2;
756*076ba34aSJunchao Zhang           j   += step;
757*076ba34aSJunchao Zhang           if (Cj(j) < cend) {
758*076ba34aSJunchao Zhang             first  = ++j;
759*076ba34aSJunchao Zhang             count -= step + 1;
760*076ba34aSJunchao Zhang           } else count = step;
761*076ba34aSJunchao Zhang         }
762*076ba34aSJunchao Zhang         Cdi(i+1) = first - (Ci(i)+Cdstart(i)); /* 'first' is the while-loop's output */
763*076ba34aSJunchao Zhang         Coi(i+1) = (Ci(i+1)-Ci(i)) - Cdi(i+1); /* Co's row len = C's row len - Cd's row len */
764*076ba34aSJunchao Zhang       });
765*076ba34aSJunchao Zhang     });
766*076ba34aSJunchao Zhang 
767*076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
768*076ba34aSJunchao Zhang     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
769*076ba34aSJunchao Zhang       update += Cdi(i);
770*076ba34aSJunchao Zhang       if (final) Cdi(i) = update;
771*076ba34aSJunchao Zhang     });
772*076ba34aSJunchao Zhang     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
773*076ba34aSJunchao Zhang       update += Coi(i);
774*076ba34aSJunchao Zhang       if (final) Coi(i) = update;
775*076ba34aSJunchao Zhang     });
776*076ba34aSJunchao Zhang 
777*076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
778*076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
779*076ba34aSJunchao Zhang     */
780*076ba34aSJunchao Zhang     Cdi_dual.modify_device();
781*076ba34aSJunchao Zhang     Coi_dual.modify_device();
782*076ba34aSJunchao Zhang     Cdi_dual.sync_host();
783*076ba34aSJunchao Zhang     Coi_dual.sync_host();
784*076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
785*076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
786*076ba34aSJunchao Zhang 
787*076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
788*076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j",Cd_nnz),Coj_dual("j",Co_nnz);
789*076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a",Cd_nnz),Coa_dual("a",Co_nnz);
790*076ba34aSJunchao Zhang 
791*076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
792*076ba34aSJunchao Zhang     MatColIdxKokkosView     Cdj = Cdj_dual.view_device(),Coj = Coj_dual.view_device();
793*076ba34aSJunchao Zhang     MatScalarKokkosView     Cda = Cda_dual.view_device(),Coa = Coa_dual.view_device();
794*076ba34aSJunchao Zhang 
795*076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
796*076ba34aSJunchao Zhang       PetscInt i       = t.league_rank(); /* row i */
797*076ba34aSJunchao Zhang       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
798*076ba34aSJunchao Zhang       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
799*076ba34aSJunchao Zhang       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
800*076ba34aSJunchao Zhang       PetscInt cdend   = cdstart + cdlen;
801*076ba34aSJunchao Zhang       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
802*076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
803*076ba34aSJunchao Zhang         if (k < cdstart) { /* k in [0, cdstart) */
804*076ba34aSJunchao Zhang           Coa(Coi(i)+k) = Ca(Ci(i)+k);
805*076ba34aSJunchao Zhang           Coj(Coi(i)+k) = Cj(Ci(i)+k);
806*076ba34aSJunchao Zhang         } else if (k < cdend) { /* k in [cdstart, cdend) */
807*076ba34aSJunchao Zhang           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
808*076ba34aSJunchao Zhang           Cdj(Cdi(i)+(k-cdstart)) = Cj(Ci(i)+k) - cstart; /* Use local col ids in Cdj */
809*076ba34aSJunchao Zhang         } else { /* k in [cdend, clen) */
810*076ba34aSJunchao Zhang           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
811*076ba34aSJunchao Zhang           Coj(Coi(i)+k-cdlen) = Cj(Ci(i)+k);
812*076ba34aSJunchao Zhang         }
813*076ba34aSJunchao Zhang       });
814*076ba34aSJunchao Zhang     });
815*076ba34aSJunchao Zhang 
816*076ba34aSJunchao Zhang     Cdj_dual.modify_device();
817*076ba34aSJunchao Zhang     Cda_dual.modify_device();
818*076ba34aSJunchao Zhang     Coj_dual.modify_device();
819*076ba34aSJunchao Zhang     Coa_dual.modify_device();
820*076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
821*076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m,n,Cd_nnz,Cdi_dual,Cdj_dual,Cda_dual);
822*076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m,N,Co_nnz,Coi_dual,Coj_dual,Coa_dual);
823*076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cdkok,&Cd);CHKERRQ(ierr);
824*076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cokok,&Co);CHKERRQ(ierr);
825*076ba34aSJunchao Zhang     ierr = MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C,Cd,Co);CHKERRQ(ierr); /* Coj will be converted to local ids within */
826*076ba34aSJunchao Zhang   }
827*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
828*076ba34aSJunchao Zhang }
829*076ba34aSJunchao Zhang 
/* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.

  It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.

  Input Parameters:
.  C        - the SEQAIJKOKKOS matrix, using global col ids

  Output Parameters:
+  C        - the updated matrix, with compacted col ids and col size
-  l2g      - the local to global mapping of C's cols

  Notes: the input matrix's col ids and col size will be changed.
*/
847*076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C,MatColIdxKokkosView& l2g)
848*076ba34aSJunchao Zhang {
849*076ba34aSJunchao Zhang   PetscErrorCode         ierr;
850*076ba34aSJunchao Zhang   Mat_SeqAIJKokkos       *ckok;
851*076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
852*076ba34aSJunchao Zhang   const PetscInt         *garray;
853*076ba34aSJunchao Zhang   PetscInt               sz;
854*076ba34aSJunchao Zhang 
855*076ba34aSJunchao Zhang   PetscFunctionBegin;
856*076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
857*076ba34aSJunchao Zhang   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJ(C,&l2gmap);CHKERRQ(ierr);
858*076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
859*076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
860*076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
861*076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
862*076ba34aSJunchao Zhang 
863*076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
864*076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingGetIndices(l2gmap,&garray);CHKERRQ(ierr);
865*076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingGetSize(l2gmap,&sz);CHKERRQ(ierr);
866*076ba34aSJunchao Zhang   if (C->cmap->n != sz) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_PLIB,"matrix column size(%D) != l2g mapping size(%D)\n", C->cmap->n,sz);
867*076ba34aSJunchao Zhang 
868*076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray,sz);
869*076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g",sz);
870*076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g,tmp);
871*076ba34aSJunchao Zhang 
872*076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingRestoreIndices(l2gmap,&garray);CHKERRQ(ierr);
873*076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingDestroy(&l2gmap);CHKERRQ(ierr);
874*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
875*076ba34aSJunchao Zhang }
876*076ba34aSJunchao Zhang 
877*076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
878*076ba34aSJunchao Zhang 
879*076ba34aSJunchao Zhang   Input Parameters:
880*076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
881*076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
882*076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
883*076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
884*076ba34aSJunchao Zhang 
885*076ba34aSJunchao Zhang   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
886*076ba34aSJunchao Zhang */
887*076ba34aSJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product,Mat A,Mat B,MatMatStruct_AB *mm)
888*076ba34aSJunchao Zhang {
889*076ba34aSJunchao Zhang   PetscErrorCode              ierr;
890*076ba34aSJunchao Zhang   Mat_MPIAIJ                  *a = static_cast<Mat_MPIAIJ*>(A->data);
891*076ba34aSJunchao Zhang   Mat                         Ad = a->A,Ao = a->B; /* diag and offdiag of A */
892*076ba34aSJunchao Zhang   IS                          glob = NULL;
893*076ba34aSJunchao Zhang   const PetscInt              *garray;
894*076ba34aSJunchao Zhang   PetscInt                    N = B->cmap->N,sz;
895*076ba34aSJunchao Zhang   ConstMatColIdxKokkosView    l2g1; /* two temp maps mapping local col ids to global ones */
896*076ba34aSJunchao Zhang   MatColIdxKokkosView         l2g2;
897*076ba34aSJunchao Zhang   Mat                         C1,C2; /* intermediate matrices */
898*076ba34aSJunchao Zhang 
899*076ba34aSJunchao Zhang   PetscFunctionBegin;
900*076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
901*076ba34aSJunchao Zhang   ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&mm->B_local);CHKERRQ(ierr);
902*076ba34aSJunchao Zhang   ierr = MatProductCreate(Ad,mm->B_local,NULL,&C1);CHKERRQ(ierr);
903*076ba34aSJunchao Zhang   ierr = MatProductSetType(C1,MATPRODUCT_AB);CHKERRQ(ierr);
904*076ba34aSJunchao Zhang   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
905*076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
906*076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
907*076ba34aSJunchao Zhang   if (!C1->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
908*076ba34aSJunchao Zhang   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
909*076ba34aSJunchao Zhang 
910*076ba34aSJunchao Zhang   ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
911*076ba34aSJunchao Zhang   ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
912*076ba34aSJunchao Zhang   const auto& tmp  = ConstMatColIdxKokkosViewHost(garray,sz); /* wrap garray as a view */
913*076ba34aSJunchao Zhang   l2g1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
914*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g1,mm->C1_global);
915*076ba34aSJunchao Zhang 
916*076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
917*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosBcast(mm->B_local,MAT_INITIAL_MATRIX,N,l2g1,a->Mvctx,mm->sf,
918*076ba34aSJunchao Zhang                               mm->abuf,mm->rows,mm->rowoffset,mm->B_other);CHKERRQ(ierr);
919*076ba34aSJunchao Zhang 
920*076ba34aSJunchao Zhang   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
921*076ba34aSJunchao Zhang   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other,l2g2);CHKERRQ(ierr);
922*076ba34aSJunchao Zhang   ierr = MatProductCreate(Ao,mm->B_other,NULL,&C2);CHKERRQ(ierr);
923*076ba34aSJunchao Zhang   ierr = MatProductSetType(C2,MATPRODUCT_AB);CHKERRQ(ierr);
924*076ba34aSJunchao Zhang   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
925*076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
926*076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
927*076ba34aSJunchao Zhang   if (!C2->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
928*076ba34aSJunchao Zhang   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
929*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2,N,l2g2,mm->C2_global);
930*076ba34aSJunchao Zhang 
931*076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
932*076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
933*076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
934*076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
935*076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
936*076ba34aSJunchao Zhang 
937*076ba34aSJunchao Zhang   mm->C1 = C1;
938*076ba34aSJunchao Zhang   mm->C2 = C2;
939*076ba34aSJunchao Zhang   ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
940*076ba34aSJunchao Zhang   ierr = ISDestroy(&glob);CHKERRQ(ierr);
941*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
942*076ba34aSJunchao Zhang }
943*076ba34aSJunchao Zhang 
944*076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
945*076ba34aSJunchao Zhang 
946*076ba34aSJunchao Zhang   Input Parameters:
947*076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
948*076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
949*076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
950*076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
951*076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
952*076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
953*076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
954*076ba34aSJunchao Zhang 
955*076ba34aSJunchao Zhang   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
956*076ba34aSJunchao Zhang */
957*076ba34aSJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product,Mat A,Mat B,PetscBool localB,PetscInt N,const ConstMatColIdxKokkosView& l2g,MatMatStruct_AtB *mm)
958*076ba34aSJunchao Zhang {
959*076ba34aSJunchao Zhang   PetscErrorCode         ierr;
960*076ba34aSJunchao Zhang   Mat_MPIAIJ             *a = static_cast<Mat_MPIAIJ*>(A->data);
961*076ba34aSJunchao Zhang   Mat                    Ad = a->A,Ao = a->B; /* diag and offdiag of A */
962*076ba34aSJunchao Zhang   Mat                    C1,C2; /* intermediate matrices */
963*076ba34aSJunchao Zhang 
964*076ba34aSJunchao Zhang   PetscFunctionBegin;
965*076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
966*076ba34aSJunchao Zhang   ierr = MatProductCreate(Ad,B,NULL,&C1);CHKERRQ(ierr);
967*076ba34aSJunchao Zhang   ierr = MatProductSetType(C1,MATPRODUCT_AtB);CHKERRQ(ierr);
968*076ba34aSJunchao Zhang   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
969*076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
970*076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
971*076ba34aSJunchao Zhang   if (!C1->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
972*076ba34aSJunchao Zhang   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
973*076ba34aSJunchao Zhang 
974*076ba34aSJunchao Zhang   if (localB) {ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g,mm->C1_global);}
975*076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos*>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
976*076ba34aSJunchao Zhang 
977*076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
978*076ba34aSJunchao Zhang   ierr = MatProductCreate(Ao,B,NULL,&C2);CHKERRQ(ierr);
979*076ba34aSJunchao Zhang   ierr = MatProductSetType(C2,MATPRODUCT_AtB);CHKERRQ(ierr);
980*076ba34aSJunchao Zhang   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
981*076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
982*076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
983*076ba34aSJunchao Zhang   if (!C2->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
984*076ba34aSJunchao Zhang   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
985*076ba34aSJunchao Zhang 
986*076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosReduce(C2,MAT_INITIAL_MATRIX,localB,N,l2g,a->Mvctx,mm->sf,mm->abuf,
987*076ba34aSJunchao Zhang                                mm->srcrowoffset,mm->dstrowoffset,mm->C2_global);CHKERRQ(ierr);
988*076ba34aSJunchao Zhang 
989*076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
990*076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
991*076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
992*076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
993*076ba34aSJunchao Zhang   mm->C1 = C1;
994*076ba34aSJunchao Zhang   mm->C2 = C2;
995*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
996*076ba34aSJunchao Zhang }
997*076ba34aSJunchao Zhang 
998*076ba34aSJunchao Zhang PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
999*076ba34aSJunchao Zhang {
1000*076ba34aSJunchao Zhang   PetscErrorCode                ierr;
1001*076ba34aSJunchao Zhang   Mat_Product                   *product = C->product;
1002*076ba34aSJunchao Zhang   MatProductType                ptype;
1003*076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos   *mmdata;
1004*076ba34aSJunchao Zhang   MatMatStruct                  *mm = NULL;
1005*076ba34aSJunchao Zhang   MatMatStruct_AB               *ab;
1006*076ba34aSJunchao Zhang   MatMatStruct_AtB              *atb;
1007*076ba34aSJunchao Zhang   Mat                           A,B,Ad,Ao,Bd,Bo;
1008*076ba34aSJunchao Zhang   const MatScalarType           one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1009*076ba34aSJunchao Zhang 
1010*076ba34aSJunchao Zhang   PetscFunctionBegin;
1011*076ba34aSJunchao Zhang   MatCheckProduct(C,1);
1012*076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos*>(product->data);
1013*076ba34aSJunchao Zhang   ptype  = product->type;
1014*076ba34aSJunchao Zhang   A      = product->A;
1015*076ba34aSJunchao Zhang   B      = product->B;
1016*076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ*>(A->data)->A;
1017*076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ*>(A->data)->B;
1018*076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ*>(B->data)->A;
1019*076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ*>(B->data)->B;
1020*076ba34aSJunchao Zhang 
1021*076ba34aSJunchao Zhang   if (mmdata->reusesym) { /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1022*076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1023*076ba34aSJunchao Zhang     ab  = mmdata->mmAB;
1024*076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
1025*076ba34aSJunchao Zhang     if (ab) {
1026*076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1027*076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1028*076ba34aSJunchao Zhang     }
1029*076ba34aSJunchao Zhang     if (atb) {
1030*076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1031*076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1032*076ba34aSJunchao Zhang     }
1033*076ba34aSJunchao Zhang     PetscFunctionReturn(0);
1034*076ba34aSJunchao Zhang   }
1035*076ba34aSJunchao Zhang 
1036*076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1037*076ba34aSJunchao Zhang     ab   = mmdata->mmAB;
1038*076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
1039*076ba34aSJunchao Zhang     if (!ab->C1->ops->productnumeric || !ab->C2->ops->productnumeric) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AB");
1040*076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1041*076ba34aSJunchao Zhang     if (ab->C1->product->B != ab->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1042*076ba34aSJunchao Zhang     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1043*076ba34aSJunchao Zhang     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1044*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1045*076ba34aSJunchao Zhang                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1046*076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
1047*076ba34aSJunchao Zhang     if (ab->C2->product->B != ab->B_other) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1048*076ba34aSJunchao Zhang     if (ab->C1->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1049*076ba34aSJunchao Zhang     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr);
1050*076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1051*076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1052*076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(ab);
1053*076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1054*076ba34aSJunchao Zhang     atb  = mmdata->mmAtB;
1055*076ba34aSJunchao Zhang     if (!atb->C1->ops->productnumeric || !atb->C2->ops->productnumeric) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AtB");
1056*076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
1057*076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&atb->B_local);CHKERRQ(ierr);
1058*076ba34aSJunchao Zhang     if (atb->C1->product->B != atb->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1059*076ba34aSJunchao Zhang     if (atb->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1060*076ba34aSJunchao Zhang     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1061*076ba34aSJunchao Zhang 
1062*076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
1063*076ba34aSJunchao Zhang     if (atb->C2->product->B != atb->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1064*076ba34aSJunchao Zhang     if (atb->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1065*076ba34aSJunchao Zhang     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1066*076ba34aSJunchao Zhang     /* Form C2_global */
1067*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_TRUE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1068*076ba34aSJunchao Zhang                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1069*076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1070*076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1071*076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(atb);
1072*076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1073*076ba34aSJunchao Zhang     ab   = mmdata->mmAB;
1074*076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1075*076ba34aSJunchao Zhang 
1076*076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
1077*076ba34aSJunchao Zhang     if (ab->C1->product->B != ab->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1078*076ba34aSJunchao Zhang     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1079*076ba34aSJunchao Zhang     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1080*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1081*076ba34aSJunchao Zhang                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1082*076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
1083*076ba34aSJunchao Zhang     if (ab->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1084*076ba34aSJunchao Zhang     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr); /* C2 = Ao * B_other */
1085*076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1086*076ba34aSJunchao Zhang 
1087*076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1088*076ba34aSJunchao Zhang     atb  = mmdata->mmAtB;
1089*076ba34aSJunchao Zhang     if (atb->C1->product->B != ab->C_petsc) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1090*076ba34aSJunchao Zhang     if (atb->C1->product->A != Bd) {ierr = MatProductReplaceMats(Bd,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1091*076ba34aSJunchao Zhang     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1092*076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
1093*076ba34aSJunchao Zhang     if (atb->C2->product->A != Bo) {ierr = MatProductReplaceMats(Bo,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1094*076ba34aSJunchao Zhang     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1095*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_FALSE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1096*076ba34aSJunchao Zhang                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1097*076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1098*076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(atb);
1099*076ba34aSJunchao Zhang   }
1100*076ba34aSJunchao Zhang   /* Split C_global to form C */
1101*076ba34aSJunchao Zhang   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_REUSE_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1102*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1103*076ba34aSJunchao Zhang }
1104*076ba34aSJunchao Zhang 
1105*076ba34aSJunchao Zhang PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1106*076ba34aSJunchao Zhang {
1107*076ba34aSJunchao Zhang   PetscErrorCode              ierr;
1108*076ba34aSJunchao Zhang   Mat                         A,B;
1109*076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1110*076ba34aSJunchao Zhang   MatProductType              ptype;
1111*076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1112*076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
1113*076ba34aSJunchao Zhang   IS                          glob = NULL;
1114*076ba34aSJunchao Zhang   const PetscInt              *garray;
1115*076ba34aSJunchao Zhang   PetscInt                    m,n,M,N,sz;
1116*076ba34aSJunchao Zhang   ConstMatColIdxKokkosView    l2g; /* map local col ids to global ones */
1117*076ba34aSJunchao Zhang 
1118*076ba34aSJunchao Zhang   PetscFunctionBegin;
1119*076ba34aSJunchao Zhang   MatCheckProduct(C,1);
1120*076ba34aSJunchao Zhang   if (product->data) SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Product data not empty");
1121*076ba34aSJunchao Zhang   ptype = product->type;
1122*076ba34aSJunchao Zhang   A     = product->A;
1123*076ba34aSJunchao Zhang   B     = product->B;
1124*076ba34aSJunchao Zhang 
1125*076ba34aSJunchao Zhang   switch (ptype) {
1126*076ba34aSJunchao Zhang     case MATPRODUCT_AB:   m = A->rmap->n; n = B->cmap->n; M = A->rmap->N; N = B->cmap->N; break;
1127*076ba34aSJunchao Zhang     case MATPRODUCT_AtB:  m = A->cmap->n; n = B->cmap->n; M = A->cmap->N; N = B->cmap->N; break;
1128*076ba34aSJunchao Zhang     case MATPRODUCT_PtAP: m = B->cmap->n; n = B->cmap->n; M = B->cmap->N; N = B->cmap->N; break; /* BtAB */
1129*076ba34aSJunchao Zhang     default: SETERRQ1(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[ptype]);
1130*076ba34aSJunchao Zhang   }
1131*076ba34aSJunchao Zhang 
1132*076ba34aSJunchao Zhang   ierr = MatSetSizes(C,m,n,M,N);CHKERRQ(ierr);
1133*076ba34aSJunchao Zhang   ierr = PetscLayoutSetUp(C->rmap);CHKERRQ(ierr);
1134*076ba34aSJunchao Zhang   ierr = PetscLayoutSetUp(C->cmap);CHKERRQ(ierr);
1135*076ba34aSJunchao Zhang   ierr = MatSetType(C,((PetscObject)A)->type_name);CHKERRQ(ierr);
1136*076ba34aSJunchao Zhang 
1137*076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1138*076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1139*076ba34aSJunchao Zhang 
1140*076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1141*076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
1142*076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,mmdata->mmAB);CHKERRQ(ierr);
1143*076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(mmdata->mmAB);
1144*076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1145*076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1146*076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
1147*076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&atb->B_local);CHKERRQ(ierr);
1148*076ba34aSJunchao Zhang     ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
1149*076ba34aSJunchao Zhang     ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
1150*076ba34aSJunchao Zhang     l2g  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatColIdxKokkosViewHost(garray,sz));
1151*076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,A,atb->B_local,PETSC_TRUE,N,l2g,atb);CHKERRQ(ierr);
1152*076ba34aSJunchao Zhang     ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
1153*076ba34aSJunchao Zhang     ierr = ISDestroy(&glob);CHKERRQ(ierr);
1154*076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(atb);
1155*076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1156*076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB(); /* tmp=A*B */
1157*076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1158*076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1159*076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
1160*076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,ab);CHKERRQ(ierr);
1161*076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1162*076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,tmp,&ab->C_petsc);CHKERRQ(ierr);
1163*076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,B,ab->C_petsc,PETSC_FALSE,N,l2g/*not used*/,atb);CHKERRQ(ierr);
1164*076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(atb);
1165*076ba34aSJunchao Zhang   }
1166*076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
1167*076ba34aSJunchao Zhang   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_INITIAL_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1168*076ba34aSJunchao Zhang   C->product->data        = mmdata;
1169*076ba34aSJunchao Zhang   C->product->destroy     = MatProductDataDestroy_MPIAIJKokkos;
1170*076ba34aSJunchao Zhang   C->ops->productnumeric  = MatProductNumeric_MPIAIJKokkos;
1171*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1172*076ba34aSJunchao Zhang }
1173*076ba34aSJunchao Zhang 
1174*076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1175*076ba34aSJunchao Zhang {
1176*076ba34aSJunchao Zhang   PetscErrorCode ierr;
1177*076ba34aSJunchao Zhang   Mat_Product    *product = mat->product;
1178*076ba34aSJunchao Zhang   PetscBool      match = PETSC_FALSE;
1179*076ba34aSJunchao Zhang   PetscBool      usecpu = PETSC_FALSE;
1180*076ba34aSJunchao Zhang 
1181*076ba34aSJunchao Zhang   PetscFunctionBegin;
1182*076ba34aSJunchao Zhang   MatCheckProduct(mat,1);
1183*076ba34aSJunchao Zhang   if (!product->A->boundtocpu && !product->B->boundtocpu) {
1184*076ba34aSJunchao Zhang     ierr = PetscObjectTypeCompare((PetscObject)product->B,((PetscObject)product->A)->type_name,&match);CHKERRQ(ierr);
1185*076ba34aSJunchao Zhang   }
1186*076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1187*076ba34aSJunchao Zhang     switch (product->type) {
1188*076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1189*076ba34aSJunchao Zhang       if (product->api_user) {
1190*076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatMatMult","Mat");CHKERRQ(ierr);
1191*076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matmatmult_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1192*076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1193*076ba34aSJunchao Zhang       } else {
1194*076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AB","Mat");CHKERRQ(ierr);
1195*076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matproduct_ab_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1196*076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1197*076ba34aSJunchao Zhang       }
1198*076ba34aSJunchao Zhang       break;
1199*076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1200*076ba34aSJunchao Zhang       if (product->api_user) {
1201*076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatTransposeMatMult","Mat");CHKERRQ(ierr);
1202*076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-mattransposematmult_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1203*076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1204*076ba34aSJunchao Zhang       } else {
1205*076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AtB","Mat");CHKERRQ(ierr);
1206*076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matproduct_atb_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1207*076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1208*076ba34aSJunchao Zhang       }
1209*076ba34aSJunchao Zhang       break;
1210*076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1211*076ba34aSJunchao Zhang       if (product->api_user) {
1212*076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatPtAP","Mat");CHKERRQ(ierr);
1213*076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1214*076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1215*076ba34aSJunchao Zhang       } else {
1216*076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_PtAP","Mat");CHKERRQ(ierr);
1217*076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matproduct_ptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1218*076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1219*076ba34aSJunchao Zhang       }
1220*076ba34aSJunchao Zhang       break;
1221*076ba34aSJunchao Zhang     default:
1222*076ba34aSJunchao Zhang       break;
1223*076ba34aSJunchao Zhang     }
1224*076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1225*076ba34aSJunchao Zhang   }
1226*076ba34aSJunchao Zhang   if (match) {
1227*076ba34aSJunchao Zhang     switch (product->type) {
1228*076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1229*076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1230*076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1231*076ba34aSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1232*076ba34aSJunchao Zhang       break;
1233*076ba34aSJunchao Zhang     default:
1234*076ba34aSJunchao Zhang       break;
1235*076ba34aSJunchao Zhang     }
1236*076ba34aSJunchao Zhang   }
1237*076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
1238*076ba34aSJunchao Zhang   if (!mat->ops->productsymbolic) {
1239*076ba34aSJunchao Zhang     ierr = MatProductSetFromOptions_MPIAIJ(mat);CHKERRQ(ierr);
1240*076ba34aSJunchao Zhang   }
1241*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1242*076ba34aSJunchao Zhang }
1243*076ba34aSJunchao Zhang 
1244*076ba34aSJunchao Zhang PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1245*076ba34aSJunchao Zhang {
1246*076ba34aSJunchao Zhang   PetscErrorCode     ierr;
1247*076ba34aSJunchao Zhang 
1248*076ba34aSJunchao Zhang   PetscFunctionBegin;
1249*076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",NULL);CHKERRQ(ierr);
1250*076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",NULL);CHKERRQ(ierr);
1251*076ba34aSJunchao Zhang   ierr = MatDestroy_MPIAIJ(A);CHKERRQ(ierr);
1252*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1253*076ba34aSJunchao Zhang }
1254*076ba34aSJunchao Zhang 
12558c3ff71bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat* newmat)
12568c3ff71bSJunchao Zhang {
12578c3ff71bSJunchao Zhang   PetscErrorCode     ierr;
12588c3ff71bSJunchao Zhang   Mat                B;
1259*076ba34aSJunchao Zhang   Mat_MPIAIJ         *a;
12608c3ff71bSJunchao Zhang 
12618c3ff71bSJunchao Zhang   PetscFunctionBegin;
12628c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
12638c3ff71bSJunchao Zhang     ierr = MatDuplicate(A,MAT_COPY_VALUES,newmat);CHKERRQ(ierr);
12648c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
12658c3ff71bSJunchao Zhang     ierr = MatCopy(A,*newmat,SAME_NONZERO_PATTERN);CHKERRQ(ierr);
12668c3ff71bSJunchao Zhang   }
12678c3ff71bSJunchao Zhang   B = *newmat;
12688c3ff71bSJunchao Zhang 
12696f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
12708c3ff71bSJunchao Zhang   ierr = PetscFree(B->defaultvectype);CHKERRQ(ierr);
12718c3ff71bSJunchao Zhang   ierr = PetscStrallocpy(VECKOKKOS,&B->defaultvectype);CHKERRQ(ierr);
12723d0639e7SStefano Zampini   ierr = PetscObjectChangeTypeName((PetscObject)B,MATMPIAIJKOKKOS);CHKERRQ(ierr);
12738c3ff71bSJunchao Zhang 
1274*076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ*>(A->data);
1275*076ba34aSJunchao Zhang   if (a->A) {ierr = MatSetType(a->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1276*076ba34aSJunchao Zhang   if (a->B) {ierr = MatSetType(a->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1277*076ba34aSJunchao Zhang   if (a->lvec) {ierr = VecSetType(a->lvec,VECSEQKOKKOS);CHKERRQ(ierr);}
1278*076ba34aSJunchao Zhang 
12798c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
12808c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
12818c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
12828c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1283*076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1284*076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
12858c3ff71bSJunchao Zhang 
12863d0639e7SStefano Zampini   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJSetPreallocation_C",MatMPIAIJSetPreallocation_MPIAIJKokkos);CHKERRQ(ierr);
1287*076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",MatMPIAIJGetLocalMatMerge_MPIAIJKokkos);CHKERRQ(ierr);
12888c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
12898c3ff71bSJunchao Zhang }
12908c3ff71bSJunchao Zhang 
12918c3ff71bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
12928c3ff71bSJunchao Zhang {
12938c3ff71bSJunchao Zhang   PetscErrorCode ierr;
12948c3ff71bSJunchao Zhang 
12958c3ff71bSJunchao Zhang   PetscFunctionBegin;
12968c3ff71bSJunchao Zhang   ierr = PetscKokkosInitializeCheck();CHKERRQ(ierr);
12978c3ff71bSJunchao Zhang   ierr = MatCreate_MPIAIJ(A);CHKERRQ(ierr);
12988c3ff71bSJunchao Zhang   ierr = MatConvert_MPIAIJ_MPIAIJKokkos(A,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&A);CHKERRQ(ierr);
12998c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13008c3ff71bSJunchao Zhang }
13018c3ff71bSJunchao Zhang 
13028c3ff71bSJunchao Zhang /*@C
13038c3ff71bSJunchao Zhang    MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
13058c3ff71bSJunchao Zhang    to Kokkos for calculations. For good matrix
13068c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
13078c3ff71bSJunchao Zhang    the parameter nz (or the array nnz).  By setting these parameters accurately,
13088c3ff71bSJunchao Zhang    performance during matrix assembly can be increased by more than a factor of 50.
13098c3ff71bSJunchao Zhang 
13108c3ff71bSJunchao Zhang    Collective
13118c3ff71bSJunchao Zhang 
13128c3ff71bSJunchao Zhang    Input Parameters:
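+  comm - MPI communicator
.  m - number of local rows (or PETSC_DECIDE to have it calculated if M is given)
.  n - number of local columns (or PETSC_DECIDE to have it calculated if N is given)
.  M - number of global rows (or PETSC_DETERMINE to have it calculated if m is given)
.  N - number of global columns (or PETSC_DETERMINE to have it calculated if n is given)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix
          (same for all local rows); not used when comm has a single process
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           (possibly different for each row) or NULL; not used when comm has a single process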
13198c3ff71bSJunchao Zhang 
13208c3ff71bSJunchao Zhang    Output Parameter:
13218c3ff71bSJunchao Zhang .  A - the matrix
13228c3ff71bSJunchao Zhang 
13238c3ff71bSJunchao Zhang    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
13248c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
13258c3ff71bSJunchao Zhang    [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]
13268c3ff71bSJunchao Zhang 
13278c3ff71bSJunchao Zhang    Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored
13298c3ff71bSJunchao Zhang 
13308c3ff71bSJunchao Zhang    The AIJ format (also called the Yale sparse matrix format or
13318c3ff71bSJunchao Zhang    compressed row storage), is fully compatible with standard Fortran 77
13328c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
13338c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
13348c3ff71bSJunchao Zhang 
13358c3ff71bSJunchao Zhang    Specify the preallocated storage with either nz or nnz (not both).
13368c3ff71bSJunchao Zhang    Set nz=PETSC_DEFAULT and nnz=NULL for PETSc to control dynamic memory
13378c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
13388c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
13398c3ff71bSJunchao Zhang 
13408c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
13418c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
13428c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
13438c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
13448c3ff71bSJunchao Zhang 
13458c3ff71bSJunchao Zhang    Level: intermediate
13468c3ff71bSJunchao Zhang 
13478c3ff71bSJunchao Zhang .seealso: MatCreate(), MatCreateAIJ(), MatSetValues(), MatSeqAIJSetColumnIndices(), MatCreateSeqAIJWithArrays(), MatCreateAIJ(), MATMPIAIJKOKKOS, MATAIJKokkos
13488c3ff71bSJunchao Zhang @*/
13498c3ff71bSJunchao Zhang PetscErrorCode  MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat *A)
13508c3ff71bSJunchao Zhang {
13518c3ff71bSJunchao Zhang   PetscErrorCode ierr;
13528c3ff71bSJunchao Zhang   PetscMPIInt    size;
13538c3ff71bSJunchao Zhang 
13548c3ff71bSJunchao Zhang   PetscFunctionBegin;
13558c3ff71bSJunchao Zhang   ierr = MatCreate(comm,A);CHKERRQ(ierr);
13568c3ff71bSJunchao Zhang   ierr = MatSetSizes(*A,m,n,M,N);CHKERRQ(ierr);
1357ffc4695bSBarry Smith   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
13588c3ff71bSJunchao Zhang   if (size > 1) {
13598c3ff71bSJunchao Zhang     ierr = MatSetType(*A,MATMPIAIJKOKKOS);CHKERRQ(ierr);
13608c3ff71bSJunchao Zhang     ierr = MatMPIAIJSetPreallocation(*A,d_nz,d_nnz,o_nz,o_nnz);CHKERRQ(ierr);
13618c3ff71bSJunchao Zhang   } else {
13628c3ff71bSJunchao Zhang     ierr = MatSetType(*A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
13638c3ff71bSJunchao Zhang     ierr = MatSeqAIJSetPreallocation(*A,d_nz,d_nnz);CHKERRQ(ierr);
13648c3ff71bSJunchao Zhang   }
13658c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13668c3ff71bSJunchao Zhang }
13678c3ff71bSJunchao Zhang 
1368a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1369042217e8SBarry Smith PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1370a587d139SMark {
1371a587d139SMark   PetscMPIInt                size,rank;
1372a587d139SMark   MPI_Comm                   comm;
1373a587d139SMark   PetscErrorCode             ierr;
1374042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat=NULL;
1375a587d139SMark 
1376a587d139SMark   PetscFunctionBegin;
1377a587d139SMark   ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRQ(ierr);
137855b25c41SPierre Jolivet   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
137955b25c41SPierre Jolivet   ierr = MPI_Comm_rank(comm,&rank);CHKERRMPI(ierr);
1380a587d139SMark   if (size == 1) {
1381a587d139SMark     ierr   = MatSeqAIJKokkosGetDeviceMat(A,&d_mat);CHKERRQ(ierr);
1382a587d139SMark   } else {
1383a587d139SMark     Mat_MPIAIJ  *aij = (Mat_MPIAIJ*)A->data;
1384a587d139SMark     ierr   = MatSeqAIJKokkosGetDeviceMat(aij->A,&d_mat);CHKERRQ(ierr);
1385a587d139SMark   }
1386a587d139SMark   // act like MatSetValues because not called on host
1387a587d139SMark   if (A->assembled) {
1388a587d139SMark     if (A->was_assembled) {
1389a587d139SMark       ierr = PetscInfo(A,"Assemble more than once already\n");CHKERRQ(ierr);
1390a587d139SMark     }
1391a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1392a587d139SMark   } else {
1393a587d139SMark     ierr = PetscInfo1(A,"Warning !assemble ??? assembled=%D\n",A->assembled);CHKERRQ(ierr);
1394a587d139SMark     // SETERRQ(comm,PETSC_ERR_SUP,"Need assemble matrix");
1395a587d139SMark   }
1396a587d139SMark   if (!d_mat) {
1397042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1398a587d139SMark     Mat_SeqAIJKokkos      *aijkokA;
1399a587d139SMark     Mat_SeqAIJ            *jaca;
1400a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1401a587d139SMark     Mat                   Amat;
1402042217e8SBarry Smith     PetscInt              *colmap;
1403042217e8SBarry Smith 
1404042217e8SBarry Smith     /* create and copy h_mat */
1405a587d139SMark     ierr = PetscInfo(A,"Create device matrix in Kokkos\n");CHKERRQ(ierr);
1406a587d139SMark     if (size == 1) {
1407a587d139SMark       Amat = A;
1408a587d139SMark       jaca = (Mat_SeqAIJ*)A->data;
1409a587d139SMark       h_mat.rstart = 0; h_mat.rend = A->rmap->n;
1410a587d139SMark       h_mat.cstart = 0; h_mat.cend = A->cmap->n;
1411a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1412a587d139SMark       h_mat.offdiag.a = NULL;
1413a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1414a587d139SMark       aijkokA->i_uncompressed_d = NULL;
1415a587d139SMark       aijkokA->colmap_d = NULL;
1416a587d139SMark     } else {
1417a587d139SMark       Mat_MPIAIJ       *aij = (Mat_MPIAIJ*)A->data;
1418a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ*)aij->B->data;
1419a587d139SMark       PetscInt         ii;
1420a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1421042217e8SBarry Smith 
1422a587d139SMark       Amat = aij->A;
1423a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos*>(aij->A->spptr);
1424a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos*>(aij->B->spptr);
1425a587d139SMark       aijkokA->i_uncompressed_d = NULL;
1426a587d139SMark       aijkokA->colmap_d = NULL;
1427a587d139SMark       jaca = (Mat_SeqAIJ*)aij->A->data;
1428b3c64f9dSJunchao Zhang       if (aij->B->cmap->n && !aij->garray) SETERRQ(comm,PETSC_ERR_PLIB,"MPIAIJ Matrix was assembled but is missing garray");
1429a587d139SMark       if (aij->B->rmap->n != aij->A->rmap->n) SETERRQ(comm,PETSC_ERR_SUP,"Only support aij->B->rmap->n == aij->A->rmap->n");
1430a587d139SMark       aij->donotstash = PETSC_TRUE;
1431a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1432a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1433042217e8SBarry Smith       ierr = PetscCalloc1(A->cmap->N,&colmap);CHKERRQ(ierr);
1434042217e8SBarry Smith       ierr = PetscLogObjectMemory((PetscObject)A,(A->cmap->N)*sizeof(PetscInt));CHKERRQ(ierr);
1435042217e8SBarry Smith       for (ii=0; ii<aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii+1;
1436a587d139SMark       // allocate B copy data
1437a587d139SMark       h_mat.rstart = A->rmap->rstart; h_mat.rend = A->rmap->rend;
1438a587d139SMark       h_mat.cstart = A->cmap->rstart; h_mat.cend = A->cmap->rend;
1439a587d139SMark       nnz = jacb->i[n];
1440a587d139SMark       if (jacb->compressedrow.use) {
1441a587d139SMark         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_i_k (jacb->i,n+1);
14429d676ae4SMark Adams         aijkokB->i_uncompressed_d = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_i_k));
1443a587d139SMark         Kokkos::deep_copy (*aijkokB->i_uncompressed_d, h_i_k);
1444a587d139SMark         h_mat.offdiag.i = aijkokB->i_uncompressed_d->data();
1445a587d139SMark       } else {
1446*076ba34aSJunchao Zhang          h_mat.offdiag.i = (PetscInt*)aijkokB->i_device_data();
1447a587d139SMark       }
1448*076ba34aSJunchao Zhang       h_mat.offdiag.j = (PetscInt*)aijkokB->j_device_data();
1449*076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1450a587d139SMark       {
1451042217e8SBarry Smith         Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_colmap_k (colmap,A->cmap->N);
14529d676ae4SMark Adams         aijkokB->colmap_d = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_colmap_k));
1453a587d139SMark         Kokkos::deep_copy (*aijkokB->colmap_d, h_colmap_k);
1454a587d139SMark         h_mat.colmap = aijkokB->colmap_d->data();
1455042217e8SBarry Smith         ierr = PetscFree(colmap);CHKERRQ(ierr);
1456a587d139SMark       }
1457a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1458a587d139SMark       h_mat.offdiag.n = n;
1459a587d139SMark     }
1460a587d139SMark     // allocate A copy data
1461a587d139SMark     nnz = jaca->i[n];
1462a587d139SMark     h_mat.diag.n = n;
1463a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
146455b25c41SPierre Jolivet     ierr = MPI_Comm_rank(comm,&h_mat.rank);CHKERRMPI(ierr);
1465042217e8SBarry Smith     if (jaca->compressedrow.use) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"A does not suppport compressed row (todo)");
1466042217e8SBarry Smith     else {
1467*076ba34aSJunchao Zhang       h_mat.diag.i = (PetscInt*)aijkokA->i_device_data();
1468a587d139SMark     }
1469*076ba34aSJunchao Zhang     h_mat.diag.j = (PetscInt*)aijkokA->j_device_data();
1470*076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1471a587d139SMark     // copy pointers and metdata to device
1472a587d139SMark     ierr = MatSeqAIJKokkosSetDeviceMat(Amat,&h_mat);CHKERRQ(ierr);
1473a587d139SMark     ierr = MatSeqAIJKokkosGetDeviceMat(Amat,&d_mat);CHKERRQ(ierr);
1474a587d139SMark     ierr = PetscInfo2(A,"Create device Mat n=%D nnz=%D\n",h_mat.diag.n, nnz);CHKERRQ(ierr);
1475a587d139SMark   }
1476a587d139SMark   *B = d_mat; // return it, set it in Mat, and set it up
1477a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1478a587d139SMark   PetscFunctionReturn(0);
1479a587d139SMark }
1480*076ba34aSJunchao Zhang 
1481*076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1482*076ba34aSJunchao Zhang {
1483*076ba34aSJunchao Zhang   Mat_SeqAIJKokkos  *aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1484*076ba34aSJunchao Zhang 
1485*076ba34aSJunchao Zhang   PetscFunctionBegin;
1486*076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1487*076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1488*076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1489*076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
1490*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1491*076ba34aSJunchao Zhang }
1492*076ba34aSJunchao Zhang 
1493*076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1494*076ba34aSJunchao Zhang {
1495*076ba34aSJunchao Zhang   PetscErrorCode    ierr;
1496*076ba34aSJunchao Zhang   PetscMPIInt       size;
1497*076ba34aSJunchao Zhang   Mat               Ad,Ao;
1498*076ba34aSJunchao Zhang   const char        *amask,*bmask;
1499*076ba34aSJunchao Zhang 
1500*076ba34aSJunchao Zhang   PetscFunctionBegin;
1501*076ba34aSJunchao Zhang   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRMPI(ierr);
1502*076ba34aSJunchao Zhang 
1503*076ba34aSJunchao Zhang   if (size == 1) {
1504*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(A,&amask);CHKERRQ(ierr);
1505*076ba34aSJunchao Zhang     ierr = PetscPrintf(PETSC_COMM_SELF,"%s\n",amask);CHKERRQ(ierr);
1506*076ba34aSJunchao Zhang   } else {
1507*076ba34aSJunchao Zhang     Ad  = ((Mat_MPIAIJ*)A->data)->A;
1508*076ba34aSJunchao Zhang     Ao  = ((Mat_MPIAIJ*)A->data)->B;
1509*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(Ad,&amask);CHKERRQ(ierr);
1510*076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(Ao,&bmask);CHKERRQ(ierr);
1511*076ba34aSJunchao Zhang     ierr = PetscPrintf(PETSC_COMM_SELF,"Diag : Off-diag = %s : %s\n",amask,bmask);CHKERRQ(ierr);
1512*076ba34aSJunchao Zhang   }
1513*076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1514*076ba34aSJunchao Zhang }
1515