xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 394ed5eb51aeb2fd4d5be41f3bc40ab0c0d7d8f0)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscsf.h>
342550becSJunchao Zhang #include <petsc/private/sfimpl.h>
48c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
542550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
711d22bbfSJunchao Zhang 
88c3ff71bSJunchao Zhang PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)
98c3ff71bSJunchao Zhang {
108c3ff71bSJunchao Zhang   PetscErrorCode   ierr;
118c3ff71bSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ*)A->data;
12a587d139SMark   Mat_SeqAIJKokkos *aijkok = mpiaij->A->spptr ? static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr) : NULL;
138c3ff71bSJunchao Zhang 
148c3ff71bSJunchao Zhang   PetscFunctionBegin;
158c3ff71bSJunchao Zhang   ierr = MatAssemblyEnd_MPIAIJ(A,mode);CHKERRQ(ierr);
16a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
17a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
18a587d139SMark   }
19a587d139SMark 
208c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
218c3ff71bSJunchao Zhang }
228c3ff71bSJunchao Zhang 
238c3ff71bSJunchao Zhang PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])
248c3ff71bSJunchao Zhang {
258c3ff71bSJunchao Zhang   PetscErrorCode ierr;
268c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
278c3ff71bSJunchao Zhang 
288c3ff71bSJunchao Zhang   PetscFunctionBegin;
298c3ff71bSJunchao Zhang   ierr = PetscLayoutSetUp(mat->rmap);CHKERRQ(ierr);
308c3ff71bSJunchao Zhang   ierr = PetscLayoutSetUp(mat->cmap);CHKERRQ(ierr);
316a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
328c3ff71bSJunchao Zhang   if (d_nnz) {
336a29ce69SStefano Zampini     PetscInt i;
348c3ff71bSJunchao Zhang     for (i=0; i<mat->rmap->n; i++) {
352c71b3e2SJacob Faibussowitsch       PetscCheckFalse(d_nnz[i] < 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,d_nnz[i]);
368c3ff71bSJunchao Zhang     }
378c3ff71bSJunchao Zhang   }
388c3ff71bSJunchao Zhang   if (o_nnz) {
396a29ce69SStefano Zampini     PetscInt i;
408c3ff71bSJunchao Zhang     for (i=0; i<mat->rmap->n; i++) {
412c71b3e2SJacob Faibussowitsch       PetscCheckFalse(o_nnz[i] < 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,o_nnz[i]);
428c3ff71bSJunchao Zhang     }
438c3ff71bSJunchao Zhang   }
446a29ce69SStefano Zampini #endif
456a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
466a29ce69SStefano Zampini   ierr = PetscTableDestroy(&mpiaij->colmap);CHKERRQ(ierr);
476a29ce69SStefano Zampini #else
486a29ce69SStefano Zampini   ierr = PetscFree(mpiaij->colmap);CHKERRQ(ierr);
496a29ce69SStefano Zampini #endif
506a29ce69SStefano Zampini   ierr = PetscFree(mpiaij->garray);CHKERRQ(ierr);
516a29ce69SStefano Zampini   ierr = VecDestroy(&mpiaij->lvec);CHKERRQ(ierr);
526a29ce69SStefano Zampini   ierr = VecScatterDestroy(&mpiaij->Mvctx);CHKERRQ(ierr);
536a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
546a29ce69SStefano Zampini   ierr = MatDestroy(&mpiaij->B);CHKERRQ(ierr);
556a29ce69SStefano Zampini 
566a29ce69SStefano Zampini   if (!mpiaij->A) {
578c3ff71bSJunchao Zhang     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->A);CHKERRQ(ierr);
588c3ff71bSJunchao Zhang     ierr = MatSetSizes(mpiaij->A,mat->rmap->n,mat->cmap->n,mat->rmap->n,mat->cmap->n);CHKERRQ(ierr);
598c3ff71bSJunchao Zhang     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->A);CHKERRQ(ierr);
606a29ce69SStefano Zampini   }
616a29ce69SStefano Zampini   if (!mpiaij->B) {
626a29ce69SStefano Zampini     PetscMPIInt size;
6355b25c41SPierre Jolivet     ierr = MPI_Comm_size(PetscObjectComm((PetscObject)mat),&size);CHKERRMPI(ierr);
648c3ff71bSJunchao Zhang     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->B);CHKERRQ(ierr);
656a29ce69SStefano Zampini     ierr = MatSetSizes(mpiaij->B,mat->rmap->n,size > 1 ? mat->cmap->N : 0,mat->rmap->n,size > 1 ? mat->cmap->N : 0);CHKERRQ(ierr);
668c3ff71bSJunchao Zhang     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->B);CHKERRQ(ierr);
678c3ff71bSJunchao Zhang   }
686a29ce69SStefano Zampini   ierr = MatSetType(mpiaij->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
696a29ce69SStefano Zampini   ierr = MatSetType(mpiaij->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);
708c3ff71bSJunchao Zhang   ierr = MatSeqAIJSetPreallocation(mpiaij->A,d_nz,d_nnz);CHKERRQ(ierr);
718c3ff71bSJunchao Zhang   ierr = MatSeqAIJSetPreallocation(mpiaij->B,o_nz,o_nnz);CHKERRQ(ierr);
728c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
738c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
748c3ff71bSJunchao Zhang }
758c3ff71bSJunchao Zhang 
768c3ff71bSJunchao Zhang PetscErrorCode MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
778c3ff71bSJunchao Zhang {
788c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
798c3ff71bSJunchao Zhang   PetscErrorCode ierr;
808c3ff71bSJunchao Zhang   PetscInt       nt;
818c3ff71bSJunchao Zhang 
828c3ff71bSJunchao Zhang   PetscFunctionBegin;
838c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
842c71b3e2SJacob Faibussowitsch   PetscCheckFalse(nt != mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
858c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
868c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->mult)(mpiaij->A,xx,yy);CHKERRQ(ierr);
878c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
888c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,yy,yy);CHKERRQ(ierr);
898c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
908c3ff71bSJunchao Zhang }
918c3ff71bSJunchao Zhang 
928c3ff71bSJunchao Zhang PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)
938c3ff71bSJunchao Zhang {
948c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
958c3ff71bSJunchao Zhang   PetscErrorCode ierr;
968c3ff71bSJunchao Zhang   PetscInt       nt;
978c3ff71bSJunchao Zhang 
988c3ff71bSJunchao Zhang   PetscFunctionBegin;
998c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
1002c71b3e2SJacob Faibussowitsch   PetscCheckFalse(nt != mat->cmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
1018c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
1028c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->multadd)(mpiaij->A,xx,yy,zz);CHKERRQ(ierr);
1038c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
1048c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,zz,zz);CHKERRQ(ierr);
1058c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1068c3ff71bSJunchao Zhang }
1078c3ff71bSJunchao Zhang 
1088c3ff71bSJunchao Zhang PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
1098c3ff71bSJunchao Zhang {
1108c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
1118c3ff71bSJunchao Zhang   PetscErrorCode ierr;
1128c3ff71bSJunchao Zhang   PetscInt       nt;
1138c3ff71bSJunchao Zhang 
1148c3ff71bSJunchao Zhang   PetscFunctionBegin;
1158c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
1162c71b3e2SJacob Faibussowitsch   PetscCheckFalse(nt != mat->rmap->n,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->rmap->n,nt);
1178c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multtranspose)(mpiaij->B,xx,mpiaij->lvec);CHKERRQ(ierr);
1188c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->multtranspose)(mpiaij->A,xx,yy);CHKERRQ(ierr);
1198c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
1208c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
1218c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1228c3ff71bSJunchao Zhang }
1238c3ff71bSJunchao Zhang 
124076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
125076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
126076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
127076ba34aSJunchao Zhang */
128076ba34aSJunchao Zhang PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS *glob,Mat *C)
129076ba34aSJunchao Zhang {
130076ba34aSJunchao Zhang   Mat            Ad,Ao;
131076ba34aSJunchao Zhang   const PetscInt *cmap;
132076ba34aSJunchao Zhang   PetscErrorCode ierr;
133076ba34aSJunchao Zhang 
134076ba34aSJunchao Zhang   PetscFunctionBegin;
135076ba34aSJunchao Zhang   ierr = MatMPIAIJGetSeqAIJ(mat,&Ad,&Ao,&cmap);CHKERRQ(ierr);
136076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosMergeMats(Ad,Ao,reuse,C);CHKERRQ(ierr);
137076ba34aSJunchao Zhang   if (glob) {
138076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
139076ba34aSJunchao Zhang     ierr = MatGetLocalSize(Ad,NULL,&dn);CHKERRQ(ierr);
140076ba34aSJunchao Zhang     ierr = MatGetLocalSize(Ao,NULL,&on);CHKERRQ(ierr);
141076ba34aSJunchao Zhang     ierr = MatGetOwnershipRangeColumn(mat,&cst,NULL);CHKERRQ(ierr);
142076ba34aSJunchao Zhang     ierr = PetscMalloc1(dn+on,&gidx);CHKERRQ(ierr);
143076ba34aSJunchao Zhang     for (i=0; i<dn; i++) gidx[i]    = cst + i;
144076ba34aSJunchao Zhang     for (i=0; i<on; i++) gidx[i+dn] = cmap[i];
145076ba34aSJunchao Zhang     ierr = ISCreateGeneral(PetscObjectComm((PetscObject)Ad),dn+on,gidx,PETSC_OWN_POINTER,glob);CHKERRQ(ierr);
146076ba34aSJunchao Zhang   }
147076ba34aSJunchao Zhang   PetscFunctionReturn(0);
148076ba34aSJunchao Zhang }
149076ba34aSJunchao Zhang 
150076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
151076ba34aSJunchao Zhang struct MatMatStruct {
152076ba34aSJunchao Zhang   MatRowMapKokkosView   Cdstart; /* Used to split sequential matrix into petsc's A, B format */
153076ba34aSJunchao Zhang   PetscSF               sf; /* SF to send/recv matrix entries */
154076ba34aSJunchao Zhang   MatScalarKokkosView   abuf; /* buf of mat values in send/recv */
155076ba34aSJunchao Zhang   Mat                   C1,C2,B_local;
156076ba34aSJunchao Zhang   KokkosCsrMatrix       C1_global,C2_global,C_global;
157076ba34aSJunchao Zhang   KernelHandle          kh;
158076ba34aSJunchao Zhang   MatMatStruct() {
159076ba34aSJunchao Zhang     C1 = C2 = B_local = NULL;
160076ba34aSJunchao Zhang     sf = NULL;
161076ba34aSJunchao Zhang   }
162076ba34aSJunchao Zhang 
163076ba34aSJunchao Zhang   ~MatMatStruct() {
164076ba34aSJunchao Zhang     MatDestroy(&C1);
165076ba34aSJunchao Zhang     MatDestroy(&C2);
166076ba34aSJunchao Zhang     MatDestroy(&B_local);
167076ba34aSJunchao Zhang     PetscSFDestroy(&sf);
168076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
169076ba34aSJunchao Zhang   }
170076ba34aSJunchao Zhang };
171076ba34aSJunchao Zhang 
172076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
173076ba34aSJunchao Zhang   MatColIdxKokkosView   rows;
174076ba34aSJunchao Zhang   MatRowMapKokkosView   rowoffset;
175076ba34aSJunchao Zhang   Mat                   B_other,C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
176076ba34aSJunchao Zhang 
177076ba34aSJunchao Zhang   MatMatStruct_AB() : B_other(NULL),C_petsc(NULL){}
178076ba34aSJunchao Zhang   ~MatMatStruct_AB() {
179076ba34aSJunchao Zhang     MatDestroy(&B_other);
180076ba34aSJunchao Zhang     MatDestroy(&C_petsc);
181076ba34aSJunchao Zhang   }
182076ba34aSJunchao Zhang };
183076ba34aSJunchao Zhang 
184076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
185076ba34aSJunchao Zhang   MatRowMapKokkosView   srcrowoffset,dstrowoffset;
186076ba34aSJunchao Zhang };
187076ba34aSJunchao Zhang 
188076ba34aSJunchao Zhang struct MatProductData_MPIAIJKokkos
189076ba34aSJunchao Zhang {
190076ba34aSJunchao Zhang   MatMatStruct_AB   *mmAB;
191076ba34aSJunchao Zhang   MatMatStruct_AtB  *mmAtB;
192076ba34aSJunchao Zhang   PetscBool         reusesym;
193076ba34aSJunchao Zhang 
194076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos(): mmAB(NULL),mmAtB(NULL),reusesym(PETSC_FALSE){}
195076ba34aSJunchao Zhang   ~MatProductData_MPIAIJKokkos() {
196076ba34aSJunchao Zhang     delete mmAB;
197076ba34aSJunchao Zhang     delete mmAtB;
198076ba34aSJunchao Zhang   }
199076ba34aSJunchao Zhang };
200076ba34aSJunchao Zhang 
201076ba34aSJunchao Zhang static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202076ba34aSJunchao Zhang {
203076ba34aSJunchao Zhang   PetscFunctionBegin;
204076ba34aSJunchao Zhang   CHKERRCXX(delete static_cast<MatProductData_MPIAIJKokkos*>(data));
205076ba34aSJunchao Zhang   PetscFunctionReturn(0);
206076ba34aSJunchao Zhang }
207076ba34aSJunchao Zhang 
208076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
209076ba34aSJunchao Zhang 
210076ba34aSJunchao Zhang    Input Parameters:
211076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
212076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
213076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
214076ba34aSJunchao Zhang 
215076ba34aSJunchao Zhang    Output Parameters:
216076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
217076ba34aSJunchao Zhang  */
218076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A,PetscInt N,const ConstMatColIdxKokkosView& l2g,KokkosCsrMatrix& csrmat)
219076ba34aSJunchao Zhang {
220076ba34aSJunchao Zhang   KokkosCsrMatrix&         orig = static_cast<Mat_SeqAIJKokkos*>(A->spptr)->csrmat;
221076ba34aSJunchao Zhang   MatColIdxKokkosView      jg("jg",orig.nnz()); /* New j array for csrmat */
222076ba34aSJunchao Zhang 
223076ba34aSJunchao Zhang   PetscFunctionBegin;
224076ba34aSJunchao Zhang   CHKERRCXX(Kokkos::parallel_for(orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) {jg(i) = l2g(orig.graph.entries(i));}));
225076ba34aSJunchao Zhang   CHKERRCXX(csrmat = KokkosCsrMatrix("csrmat",orig.numRows(),N,orig.nnz(),orig.values,orig.graph.row_map,jg));
226076ba34aSJunchao Zhang   PetscFunctionReturn(0);
227076ba34aSJunchao Zhang }
228076ba34aSJunchao Zhang 
229076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
230076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
231076ba34aSJunchao Zhang 
232076ba34aSJunchao Zhang   Input Parameters:
233076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
234076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
235076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
236076ba34aSJunchao Zhang 
237076ba34aSJunchao Zhang   Output Parameters:
238076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
239076ba34aSJunchao Zhang */
240076ba34aSJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat,Mat A,Mat B)
241076ba34aSJunchao Zhang {
242076ba34aSJunchao Zhang   PetscErrorCode      ierr;
243076ba34aSJunchao Zhang   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
244076ba34aSJunchao Zhang   PetscInt            m,n,M,N,Am,An,Bm,Bn;
245076ba34aSJunchao Zhang   Mat_SeqAIJKokkos    *bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
246076ba34aSJunchao Zhang 
247076ba34aSJunchao Zhang   PetscFunctionBegin;
248076ba34aSJunchao Zhang   ierr = MatGetSize(mat,&M,&N);CHKERRQ(ierr);
249076ba34aSJunchao Zhang   ierr = MatGetLocalSize(mat,&m,&n);CHKERRQ(ierr);
250076ba34aSJunchao Zhang   ierr = MatGetLocalSize(A,&Am,&An);CHKERRQ(ierr);
251076ba34aSJunchao Zhang   ierr = MatGetLocalSize(B,&Bm,&Bn);CHKERRQ(ierr);
252076ba34aSJunchao Zhang 
2532c71b3e2SJacob Faibussowitsch   PetscCheckFalse(m != Am || m != Bm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of rows do not match");
2542c71b3e2SJacob Faibussowitsch   PetscCheckFalse(n != An,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of columns do not match");
2552c71b3e2SJacob Faibussowitsch   PetscCheckFalse(N != Bn,PETSC_COMM_SELF,PETSC_ERR_PLIB,"global number of columns do not match");
2562c71b3e2SJacob Faibussowitsch   PetscCheckFalse(mpiaij->A || mpiaij->B,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A, B of the MPIAIJ matrix are not empty");
257076ba34aSJunchao Zhang   mpiaij->A = A;
258076ba34aSJunchao Zhang   mpiaij->B = B;
259076ba34aSJunchao Zhang 
260076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
261076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
262076ba34aSJunchao Zhang 
263076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE);CHKERRQ(ierr);
264076ba34aSJunchao Zhang   ierr = MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
265076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
266076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
267076ba34aSJunchao Zhang   */
268076ba34aSJunchao Zhang   ierr = MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
269076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_FALSE);CHKERRQ(ierr);
270076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NEW_NONZERO_LOCATION_ERR,PETSC_TRUE);CHKERRQ(ierr);
271076ba34aSJunchao Zhang 
272076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
273076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
274076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
275076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
276076ba34aSJunchao Zhang   PetscFunctionReturn(0);
277076ba34aSJunchao Zhang }
278076ba34aSJunchao Zhang 
279076ba34aSJunchao Zhang /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
280076ba34aSJunchao Zhang 
281076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
282076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
283076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
284076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
285076ba34aSJunchao Zhang 
286076ba34aSJunchao Zhang    Collective on comm of ownerSF
287076ba34aSJunchao Zhang 
288076ba34aSJunchao Zhang    Input Parameters:
289076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
290076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
291076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
292076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
293076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
294076ba34aSJunchao Zhang 
295076ba34aSJunchao Zhang    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
296076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
297076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
298076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
299076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
300076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
301076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
302076ba34aSJunchao Zhang */
303076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosBcast(Mat B,MatReuse reuse,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
304076ba34aSJunchao Zhang                                            PetscSF& bcastSF,MatScalarKokkosView& abuf,MatColIdxKokkosView& rows,
305076ba34aSJunchao Zhang                                            MatRowMapKokkosView& rowoffset,Mat& C)
306076ba34aSJunchao Zhang {
307076ba34aSJunchao Zhang   PetscErrorCode               ierr;
308076ba34aSJunchao Zhang   Mat_SeqAIJKokkos             *bkok,*ckok;
309076ba34aSJunchao Zhang 
310076ba34aSJunchao Zhang   PetscFunctionBegin;
311076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosSyncDevice(B);CHKERRQ(ierr); /* Make sure B->spptr is accessible */
312076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
313076ba34aSJunchao Zhang 
314076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
315076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
316076ba34aSJunchao Zhang 
317076ba34aSJunchao Zhang     const auto& Ba = bkok->a_dual.view_device();
318076ba34aSJunchao Zhang     const auto& Bi = bkok->i_dual.view_device();
319076ba34aSJunchao Zhang     const auto& Ca = ckok->a_dual.view_device();
320076ba34aSJunchao Zhang 
321076ba34aSJunchao Zhang     /* Copy Ba to abuf */
322076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
323076ba34aSJunchao Zhang       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
324076ba34aSJunchao Zhang       PetscInt r    = rows(i);
325076ba34aSJunchao Zhang       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
326076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
327076ba34aSJunchao Zhang         abuf(base+k) = Ba(Bi(r)+k);
328076ba34aSJunchao Zhang       });
329076ba34aSJunchao Zhang     });
330076ba34aSJunchao Zhang 
331076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
332076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr); /* TODO: get memtype for abuf */
333076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr);
334076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
335076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
336076ba34aSJunchao Zhang     MPI_Comm       comm;
337076ba34aSJunchao Zhang     PetscMPIInt    tag;
338076ba34aSJunchao Zhang     PetscInt       k,Cm,Cn,Cnnz,*Ci_h,nroots,nleaves;
339076ba34aSJunchao Zhang 
340076ba34aSJunchao Zhang     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRMPI(ierr);
341076ba34aSJunchao Zhang     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
342076ba34aSJunchao Zhang     Cm   = nleaves; /* row size of C */
343076ba34aSJunchao Zhang     Cn   = N;  /* col size of C, which initially uses global ids, so we can safely set its col size as N */
344076ba34aSJunchao Zhang 
345076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
346076ba34aSJunchao Zhang     PetscInt       *Browlens;
347076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
348076ba34aSJunchao Zhang     ierr = PetscMalloc1(nroots,&Browlens);CHKERRQ(ierr);
349076ba34aSJunchao Zhang     for (k=0; k<nroots; k++) Browlens[k] = tmp[k+1]-tmp[k];
350076ba34aSJunchao Zhang 
351076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
352076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i",Cm+1); /* C's rowmap */
353076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
354076ba34aSJunchao Zhang     Ci_h[0] = 0;
355076ba34aSJunchao Zhang     ierr    = PetscSFBcastWithMemTypeBegin(ownerSF,MPIU_INT,PETSC_MEMTYPE_HOST,Browlens,PETSC_MEMTYPE_HOST,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
356076ba34aSJunchao Zhang     ierr    = PetscSFBcastEnd(ownerSF,MPIU_INT,Browlens,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
357076ba34aSJunchao Zhang     for (k=1; k<Cm+1; k++) Ci_h[k] += Ci_h[k-1]; /* Convert lens to CSR */
358076ba34aSJunchao Zhang     Cnnz    = Ci_h[Cm];
359076ba34aSJunchao Zhang     Ci.modify_host();
360076ba34aSJunchao Zhang     Ci.sync_device();
361076ba34aSJunchao Zhang 
362076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
363076ba34aSJunchao Zhang     MatColIdxKokkosDualView  Cj("j",Cnnz);
364076ba34aSJunchao Zhang     MatScalarKokkosDualView  Ca("a",Cnnz);
365076ba34aSJunchao Zhang 
366076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
367076ba34aSJunchao Zhang     const PetscMPIInt *iranks,*ranks;
368076ba34aSJunchao Zhang     const PetscInt    *ioffset,*irootloc,*roffset;
369076ba34aSJunchao Zhang     PetscInt          i,j,niranks,nranks,*sdisp,*rdisp,*rowptr;
370076ba34aSJunchao Zhang     MPI_Request       *reqs;
371076ba34aSJunchao Zhang 
372076ba34aSJunchao Zhang     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&irootloc);CHKERRQ(ierr); /* irootloc[] contains indices of rows I need to send to each receiver */
373076ba34aSJunchao Zhang     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* recv info */
374076ba34aSJunchao Zhang 
375076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
376076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
377076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
378076ba34aSJunchao Zhang 
379076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
380076ba34aSJunchao Zhang     */
381076ba34aSJunchao Zhang     ierr = PetscMalloc3(niranks+1,&sdisp,nranks,&rdisp,niranks+nranks,&reqs);CHKERRQ(ierr);
382076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h",ioffset[niranks]+1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
383076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
384076ba34aSJunchao Zhang 
385076ba34aSJunchao Zhang     sdisp[0] = 0;
386076ba34aSJunchao Zhang     rowptr[0]  = 0;
387076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) { /* for each receiver */
388076ba34aSJunchao Zhang       PetscInt len, nz = 0;
389076ba34aSJunchao Zhang       for (j=ioffset[i]; j<ioffset[i+1]; j++) { /* for each row to this receiver */
390076ba34aSJunchao Zhang         len         = Browlens[irootloc[j]];
391076ba34aSJunchao Zhang         rowptr[j+1] = rowptr[j] + len;
392076ba34aSJunchao Zhang         nz         += len;
393076ba34aSJunchao Zhang       }
394076ba34aSJunchao Zhang       sdisp[i+1] = sdisp[i] + nz;
395076ba34aSJunchao Zhang     }
396076ba34aSJunchao Zhang     ierr = PetscCommGetNewTag(comm,&tag);CHKERRMPI(ierr);
397076ba34aSJunchao Zhang     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
398076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
399076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
400076ba34aSJunchao Zhang 
401076ba34aSJunchao Zhang     PetscInt    nleaves2 = Cnnz; /* leaves are the nonzeros I will receive */
402076ba34aSJunchao Zhang     PetscInt    nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
403076ba34aSJunchao Zhang     PetscSFNode *iremote;
404076ba34aSJunchao Zhang     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr);
405076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) { /* for each sender */
406076ba34aSJunchao Zhang       k = 0;
407076ba34aSJunchao Zhang       for (j=Ci_h[roffset[i]]; j<Ci_h[roffset[i+1]]; j++) {
408076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
409076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
410076ba34aSJunchao Zhang         k++;
411076ba34aSJunchao Zhang       }
412076ba34aSJunchao Zhang     }
413076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
414076ba34aSJunchao Zhang     ierr = PetscSFCreate(comm,&bcastSF);CHKERRQ(ierr);
415076ba34aSJunchao Zhang     ierr = PetscSFSetGraph(bcastSF,nroots2,nleaves2,NULL/*ilocal*/,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
416076ba34aSJunchao Zhang 
417076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
418076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
419076ba34aSJunchao Zhang     */
420076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc,ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
421076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows",ioffset[niranks]);
422076ba34aSJunchao Zhang     Kokkos::deep_copy(rows,rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
423076ba34aSJunchao Zhang 
424076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
425076ba34aSJunchao Zhang 
426076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf",sdisp[niranks]); /* send buf for (global) col ids */
427076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf",sdisp[niranks]); /* send buf for mat values */
428076ba34aSJunchao Zhang 
429076ba34aSJunchao Zhang     const auto& Ba = bkok->a_dual.view_device();
430076ba34aSJunchao Zhang     const auto& Bi = bkok->i_dual.view_device();
431076ba34aSJunchao Zhang     const auto& Bj = bkok->j_dual.view_device();
432076ba34aSJunchao Zhang 
433076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
434076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
435076ba34aSJunchao Zhang       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
436076ba34aSJunchao Zhang       PetscInt r    = rows(i);
437076ba34aSJunchao Zhang       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
438076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
439076ba34aSJunchao Zhang         abuf(base+k) = Ba(Bi(r)+k);
440076ba34aSJunchao Zhang         jbuf(base+k) = l2g(Bj(Bi(r)+k));
441076ba34aSJunchao Zhang       });
442076ba34aSJunchao Zhang     });
443076ba34aSJunchao Zhang 
444076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
445076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
446076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
447076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
448076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
449076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
450076ba34aSJunchao Zhang     Cj.sync_host();
451076ba34aSJunchao Zhang     Ca.modify_device();
452076ba34aSJunchao Zhang 
453076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
454076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm,Cn,Cnnz,Ci,Cj,Ca);
455076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,ckok,&C);CHKERRQ(ierr);
456076ba34aSJunchao Zhang     ierr = PetscFree3(sdisp,rdisp,reqs);CHKERRQ(ierr);
457076ba34aSJunchao Zhang     ierr = PetscFree(Browlens);CHKERRQ(ierr);
45898921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
459076ba34aSJunchao Zhang   PetscFunctionReturn(0);
460076ba34aSJunchao Zhang }
461076ba34aSJunchao Zhang 
462076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
463076ba34aSJunchao Zhang 
464076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
465076ba34aSJunchao Zhang 
466076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
467076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
468076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
469076ba34aSJunchao Zhang 
470076ba34aSJunchao Zhang   Input Parameters:
471076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
472076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
473076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
474076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
475076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
476076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
477076ba34aSJunchao Zhang 
478076ba34aSJunchao Zhang   Output Parameters:
479076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
480076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
481076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
482076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
483076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
484076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
485076ba34aSJunchao Zhang 
486076ba34aSJunchao Zhang    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
487076ba34aSJunchao Zhang  */
488076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosReduce(Mat A,MatReuse reuse,PetscBool local,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
489076ba34aSJunchao Zhang                                             PetscSF& reduceSF,MatScalarKokkosView& abuf,
490076ba34aSJunchao Zhang                                             MatRowMapKokkosView& srcrowoffset,MatRowMapKokkosView& dstrowoffset,
491076ba34aSJunchao Zhang                                             KokkosCsrMatrix& C)
492076ba34aSJunchao Zhang {
493076ba34aSJunchao Zhang   PetscErrorCode         ierr;
494076ba34aSJunchao Zhang   PetscInt               i,r,Am,An,Annz,Cnnz,nrows;
495076ba34aSJunchao Zhang   const PetscInt         *Ai;
496076ba34aSJunchao Zhang   Mat_SeqAIJKokkos       *akok;
497076ba34aSJunchao Zhang 
498076ba34aSJunchao Zhang   PetscFunctionBegin;
499076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosSyncDevice(A);CHKERRQ(ierr); /* So that A's latest data is on device */
500076ba34aSJunchao Zhang   ierr = MatGetSize(A,&Am,&An);
501076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ*>(A->data)->i;
502076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
503076ba34aSJunchao Zhang   Annz = Ai[Am];
504076ba34aSJunchao Zhang 
505076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
506076ba34aSJunchao Zhang     /* Send Aa to abuf */
507076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
508076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
509076ba34aSJunchao Zhang 
510076ba34aSJunchao Zhang     /* Copy abuf to Ca */
511076ba34aSJunchao Zhang     const MatScalarKokkosView& Ca = C.values;
512076ba34aSJunchao Zhang     nrows = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
513076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
514076ba34aSJunchao Zhang       PetscInt i   = t.league_rank();
515076ba34aSJunchao Zhang       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
516076ba34aSJunchao Zhang       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
517076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {Ca(dst+k) = abuf(src+k);});
518076ba34aSJunchao Zhang     });
519076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
520076ba34aSJunchao Zhang     MPI_Comm               comm;
521076ba34aSJunchao Zhang     MPI_Request            *reqs;
522076ba34aSJunchao Zhang     PetscMPIInt            tag;
523076ba34aSJunchao Zhang     PetscInt               Cm;
524076ba34aSJunchao Zhang 
525076ba34aSJunchao Zhang     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRQ(ierr);
526076ba34aSJunchao Zhang     ierr = PetscCommGetNewTag(comm,&tag);CHKERRQ(ierr);
527076ba34aSJunchao Zhang 
528076ba34aSJunchao Zhang     PetscInt niranks,nranks,nroots,nleaves;
529076ba34aSJunchao Zhang     const PetscMPIInt *iranks,*ranks;
530076ba34aSJunchao Zhang     const PetscInt *ioffset,*rows,*roffset;  /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
531076ba34aSJunchao Zhang     ierr = PetscSFSetUp(ownerSF);CHKERRQ(ierr);
532076ba34aSJunchao Zhang     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&rows);CHKERRQ(ierr); /* recv info: iranks[] will send rows to me */
533076ba34aSJunchao Zhang     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* send info */
534076ba34aSJunchao Zhang     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
5352c71b3e2SJacob Faibussowitsch     PetscCheckFalse(nleaves != Am,PETSC_COMM_SELF,PETSC_ERR_PLIB,"ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")",nleaves,Am);
536076ba34aSJunchao Zhang     Cm    = nroots;
537076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
538076ba34aSJunchao Zhang 
539076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
540076ba34aSJunchao Zhang     PetscInt                *srowlens; /* send buf of row lens */
541076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h",nrows+1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
542076ba34aSJunchao Zhang     PetscInt                *rrowlens = rrowlens_h.data();
543076ba34aSJunchao Zhang 
544076ba34aSJunchao Zhang     ierr = PetscMalloc2(Am,&srowlens,niranks+nranks,&reqs);CHKERRQ(ierr);
545076ba34aSJunchao Zhang     for (i=0; i<Am; i++) srowlens[i] = Ai[i+1] - Ai[i];
546076ba34aSJunchao Zhang     rrowlens[0] = 0;
547076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
548076ba34aSJunchao Zhang     for (i=0; i<niranks; i++){ierr = MPI_Irecv(&rrowlens[ioffset[i]],ioffset[i+1]-ioffset[i],MPIU_INT,iranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
549076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) {ierr = MPI_Isend(&srowlens[roffset[i]],roffset[i+1]-roffset[i],MPIU_INT,ranks[i],tag,comm,&reqs[niranks+i]);CHKERRMPI(ierr);}
550076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
551076ba34aSJunchao Zhang 
552076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
553076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i",Cm+1);
554076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h,0); /* Zero Ci */
555076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
556076ba34aSJunchao Zhang 
557076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) {
558076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
559076ba34aSJunchao Zhang      #if defined(PETSC_USE_DEBUG)
5602c71b3e2SJacob Faibussowitsch       PetscCheckFalse(r<0 || r>=Cm,PETSC_COMM_SELF,PETSC_ERR_PLIB,"local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")",r,Cm);
561076ba34aSJunchao Zhang      #endif
562076ba34aSJunchao Zhang       Ci_ptr[r+1] += rrowlens[i]; /* add to length of row r in C */
563076ba34aSJunchao Zhang     }
564076ba34aSJunchao Zhang     for (i=0; i<Cm; i++) Ci_ptr[i+1] += Ci_ptr[i]; /* to CSR format */
565076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
566076ba34aSJunchao Zhang 
567076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
568076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h",nrows);
569076ba34aSJunchao Zhang     PetscInt                *dstrowoffset_hptr = dstrowoffset_h.data();
570076ba34aSJunchao Zhang     PetscInt                *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
571076ba34aSJunchao Zhang 
572076ba34aSJunchao Zhang     ierr = PetscCalloc1(Cm,&currowlens);CHKERRQ(ierr); /* Init with zero, to be added to */
573076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) { /* for each row I receive */
574076ba34aSJunchao Zhang       r                    = rows[i]; /* row id in C */
575076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
576076ba34aSJunchao Zhang       currowlens[r]       += rrowlens[i]; /* accumulate to length of row r in C */
577076ba34aSJunchao Zhang     }
578076ba34aSJunchao Zhang     ierr = PetscFree(currowlens);CHKERRQ(ierr);
579076ba34aSJunchao Zhang 
580076ba34aSJunchao Zhang     rrowlens--;
581076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) rrowlens[i+1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
582076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),dstrowoffset_h);
583076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
584076ba34aSJunchao Zhang 
585076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
586076ba34aSJunchao Zhang     PetscInt *sdisp,*rdisp; /* buffer to send offsets of roots, and buffer to recv them */
587076ba34aSJunchao Zhang     ierr = PetscMalloc2(niranks,&sdisp,nranks,&rdisp);CHKERRQ(ierr);
588076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
589076ba34aSJunchao Zhang     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
590076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
591076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
592076ba34aSJunchao Zhang 
593076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
594076ba34aSJunchao Zhang     PetscInt    nroots2 = Cnnz,nleaves2 = Annz;
595076ba34aSJunchao Zhang     PetscSFNode *iremote;
596076ba34aSJunchao Zhang     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr); /* no free, since memory will be given to reduceSF */
597076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) {
598076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i]; /* root offset at this root rank */
599076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]]; /* leaf base */
600076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i+1]] - leafbase; /* I will send nz nonzeros to this root rank */
601076ba34aSJunchao Zhang       for (PetscInt k=0; k<nz; k++) {
602076ba34aSJunchao Zhang         iremote[leafbase+k].rank  = ranks[i];
603076ba34aSJunchao Zhang         iremote[leafbase+k].index = rootbase + k;
604076ba34aSJunchao Zhang       }
605076ba34aSJunchao Zhang     }
606076ba34aSJunchao Zhang     ierr = PetscSFCreate(comm,&reduceSF);CHKERRQ(ierr);
607076ba34aSJunchao Zhang     ierr = PetscSFSetGraph(reduceSF,nroots2,nleaves2,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
608076ba34aSJunchao Zhang     ierr = PetscFree2(sdisp,rdisp);CHKERRQ(ierr);
609076ba34aSJunchao Zhang 
610076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
611076ba34aSJunchao Zhang 
612076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
613076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
614076ba34aSJunchao Zhang     if (local) {
615076ba34aSJunchao Zhang       Ajg = MatColIdxKokkosView("j",Annz);
616076ba34aSJunchao Zhang       const MatColIdxKokkosView& Aj = akok->j_dual.view_device();
617076ba34aSJunchao Zhang       Kokkos::parallel_for(Annz,KOKKOS_LAMBDA(const PetscInt i) {Ajg(i) = l2g(Aj(i));});
618076ba34aSJunchao Zhang     } else {
619076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
620076ba34aSJunchao Zhang     }
621076ba34aSJunchao Zhang 
622076ba34aSJunchao Zhang     MatColIdxKokkosView   jbuf("jbuf",Cnnz);
623076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf",Cnnz);
624076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
625076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
626076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
627076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
628076ba34aSJunchao Zhang 
629076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
630076ba34aSJunchao Zhang     MatRowMapKokkosView    Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),Ci_h); /* Ci is an alias of Ci_h if no device */
631076ba34aSJunchao Zhang     MatColIdxKokkosView    Cj("j",Cnnz);
632076ba34aSJunchao Zhang     MatScalarKokkosView    Ca("a",Cnnz);
633076ba34aSJunchao Zhang 
634076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
635076ba34aSJunchao Zhang       PetscInt i   = t.league_rank();
636076ba34aSJunchao Zhang       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
637076ba34aSJunchao Zhang       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
638076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {
639076ba34aSJunchao Zhang         Ca(dst+k) = abuf(src+k);
640076ba34aSJunchao Zhang         Cj(dst+k) = jbuf(src+k);
641076ba34aSJunchao Zhang       });
642076ba34aSJunchao Zhang     });
643076ba34aSJunchao Zhang 
644076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
645076ba34aSJunchao Zhang     C    = KokkosCsrMatrix("csrmat",Cm,N,Cnnz,Ca,Ci,Cj);
646076ba34aSJunchao Zhang     ierr = PetscFree2(srowlens,reqs);CHKERRQ(ierr);
64798921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
648076ba34aSJunchao Zhang   PetscFunctionReturn(0);
649076ba34aSJunchao Zhang }
650076ba34aSJunchao Zhang 
651076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
652076ba34aSJunchao Zhang 
653076ba34aSJunchao Zhang   Input Parameters:
654076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
655076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
656076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
657076ba34aSJunchao Zhang -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
658076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
659076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
660076ba34aSJunchao Zhang 
661076ba34aSJunchao Zhang   Output Parameters:
662076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
663076ba34aSJunchao Zhang -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
664076ba34aSJunchao Zhang 
665076ba34aSJunchao Zhang   Notes:
666076ba34aSJunchao Zhang    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
667076ba34aSJunchao Zhang  */
668076ba34aSJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C,MatReuse reuse,const KokkosCsrMatrix& csrmat,MatRowMapKokkosView& Cdstart)
669076ba34aSJunchao Zhang {
670076ba34aSJunchao Zhang   PetscErrorCode                  ierr;
671076ba34aSJunchao Zhang   const MatScalarKokkosView&      Ca = csrmat.values;
672076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView& Ci = csrmat.graph.row_map;
673076ba34aSJunchao Zhang   PetscInt                        m,n,N;
674076ba34aSJunchao Zhang 
675076ba34aSJunchao Zhang   PetscFunctionBegin;
676076ba34aSJunchao Zhang   ierr = MatGetLocalSize(C,&m,&n);CHKERRQ(ierr);
677076ba34aSJunchao Zhang   ierr = MatGetSize(C,NULL,&N);CHKERRQ(ierr);
678076ba34aSJunchao Zhang 
679076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
680076ba34aSJunchao Zhang     Mat_MPIAIJ                  *mpiaij = static_cast<Mat_MPIAIJ*>(C->data);
681076ba34aSJunchao Zhang     Mat_SeqAIJKokkos            *akok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr);
682076ba34aSJunchao Zhang     Mat_SeqAIJKokkos            *bkok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->B->spptr);
683076ba34aSJunchao Zhang     const MatScalarKokkosView&  Cda = akok->a_dual.view_device(),Coa = bkok->a_dual.view_device();
684076ba34aSJunchao Zhang     const MatRowMapKokkosView&  Cdi = akok->i_dual.view_device(),Coi = bkok->i_dual.view_device();
685076ba34aSJunchao Zhang 
686076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
687076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
688076ba34aSJunchao Zhang       PetscInt i       = t.league_rank(); /* row i */
689076ba34aSJunchao Zhang       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
690076ba34aSJunchao Zhang       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
691076ba34aSJunchao Zhang       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
692076ba34aSJunchao Zhang       PetscInt cdend   = cdstart + cdlen;
693076ba34aSJunchao Zhang       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
694076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
695076ba34aSJunchao Zhang         if (k < cdstart) {  /* k in [0, cdstart) */
696076ba34aSJunchao Zhang           Coa(Coi(i)+k) = Ca(Ci(i)+k);
697076ba34aSJunchao Zhang         } else if (k < cdend) { /* k in [cdstart, cdend) */
698076ba34aSJunchao Zhang           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
699076ba34aSJunchao Zhang         } else { /* k in [cdend, clen) */
700076ba34aSJunchao Zhang           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
701076ba34aSJunchao Zhang         }
702076ba34aSJunchao Zhang       });
703076ba34aSJunchao Zhang     });
704076ba34aSJunchao Zhang 
705076ba34aSJunchao Zhang     akok->a_dual.modify_device();
706076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
707076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
708076ba34aSJunchao Zhang     Mat                         Cd,Co;
709076ba34aSJunchao Zhang     const MatColIdxKokkosView&  Cj = csrmat.graph.entries;
710076ba34aSJunchao Zhang     MatRowMapKokkosDualView     Cdi_dual("i",m+1),Coi_dual("i",m+1);
711076ba34aSJunchao Zhang     MatRowMapKokkosView         Cdi = Cdi_dual.view_device(),Coi = Coi_dual.view_device();
712076ba34aSJunchao Zhang     PetscInt                    cstart,cend;
713076ba34aSJunchao Zhang 
714076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
715076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
716076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
717076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
718076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
719076ba34aSJunchao Zhang      */
720076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart",m);
721076ba34aSJunchao Zhang     ierr    = PetscLayoutGetRange(C->cmap,&cstart,&cend);CHKERRQ(ierr); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
722076ba34aSJunchao Zhang 
723076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
724076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
725076ba34aSJunchao Zhang      */
726076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, 1),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
727076ba34aSJunchao Zhang       Kokkos::single(Kokkos::PerTeam(t), [=] () { /* Only one thread works in a team */
728076ba34aSJunchao Zhang         PetscInt i = t.league_rank(); /* row i */
729076ba34aSJunchao Zhang         PetscInt j,first,count,step;
730076ba34aSJunchao Zhang 
731076ba34aSJunchao Zhang         if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
732076ba34aSJunchao Zhang           Cdi(0) = 0;
733076ba34aSJunchao Zhang           Coi(0) = 0;
734076ba34aSJunchao Zhang         }
735076ba34aSJunchao Zhang 
736076ba34aSJunchao Zhang         /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
737076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
738076ba34aSJunchao Zhang         */
739076ba34aSJunchao Zhang         count = Ci(i+1)-Ci(i);
740076ba34aSJunchao Zhang         first = Ci(i);
741076ba34aSJunchao Zhang         while (count > 0) {
742076ba34aSJunchao Zhang           j    = first;
743076ba34aSJunchao Zhang           step = count / 2;
744076ba34aSJunchao Zhang           j   += step;
745076ba34aSJunchao Zhang           if (Cj(j) < cstart) {
746076ba34aSJunchao Zhang             first  = ++j;
747076ba34aSJunchao Zhang             count -= step + 1;
748076ba34aSJunchao Zhang           } else count = step;
749076ba34aSJunchao Zhang         }
750076ba34aSJunchao Zhang         Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
751076ba34aSJunchao Zhang 
752076ba34aSJunchao Zhang         /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
753076ba34aSJunchao Zhang         count = Ci(i+1) - first;
754076ba34aSJunchao Zhang         while (count > 0) {
755076ba34aSJunchao Zhang           j    = first;
756076ba34aSJunchao Zhang           step = count / 2;
757076ba34aSJunchao Zhang           j   += step;
758076ba34aSJunchao Zhang           if (Cj(j) < cend) {
759076ba34aSJunchao Zhang             first  = ++j;
760076ba34aSJunchao Zhang             count -= step + 1;
761076ba34aSJunchao Zhang           } else count = step;
762076ba34aSJunchao Zhang         }
763076ba34aSJunchao Zhang         Cdi(i+1) = first - (Ci(i)+Cdstart(i)); /* 'first' is the while-loop's output */
764076ba34aSJunchao Zhang         Coi(i+1) = (Ci(i+1)-Ci(i)) - Cdi(i+1); /* Co's row len = C's row len - Cd's row len */
765076ba34aSJunchao Zhang       });
766076ba34aSJunchao Zhang     });
767076ba34aSJunchao Zhang 
768076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
769076ba34aSJunchao Zhang     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
770076ba34aSJunchao Zhang       update += Cdi(i);
771076ba34aSJunchao Zhang       if (final) Cdi(i) = update;
772076ba34aSJunchao Zhang     });
773076ba34aSJunchao Zhang     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
774076ba34aSJunchao Zhang       update += Coi(i);
775076ba34aSJunchao Zhang       if (final) Coi(i) = update;
776076ba34aSJunchao Zhang     });
777076ba34aSJunchao Zhang 
778076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
779076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
780076ba34aSJunchao Zhang     */
781076ba34aSJunchao Zhang     Cdi_dual.modify_device();
782076ba34aSJunchao Zhang     Coi_dual.modify_device();
783076ba34aSJunchao Zhang     Cdi_dual.sync_host();
784076ba34aSJunchao Zhang     Coi_dual.sync_host();
785076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
786076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
787076ba34aSJunchao Zhang 
788076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
789076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j",Cd_nnz),Coj_dual("j",Co_nnz);
790076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a",Cd_nnz),Coa_dual("a",Co_nnz);
791076ba34aSJunchao Zhang 
792076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
793076ba34aSJunchao Zhang     MatColIdxKokkosView     Cdj = Cdj_dual.view_device(),Coj = Coj_dual.view_device();
794076ba34aSJunchao Zhang     MatScalarKokkosView     Cda = Cda_dual.view_device(),Coa = Coa_dual.view_device();
795076ba34aSJunchao Zhang 
796076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
797076ba34aSJunchao Zhang       PetscInt i       = t.league_rank(); /* row i */
798076ba34aSJunchao Zhang       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
799076ba34aSJunchao Zhang       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
800076ba34aSJunchao Zhang       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
801076ba34aSJunchao Zhang       PetscInt cdend   = cdstart + cdlen;
802076ba34aSJunchao Zhang       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
803076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
804076ba34aSJunchao Zhang         if (k < cdstart) { /* k in [0, cdstart) */
805076ba34aSJunchao Zhang           Coa(Coi(i)+k) = Ca(Ci(i)+k);
806076ba34aSJunchao Zhang           Coj(Coi(i)+k) = Cj(Ci(i)+k);
807076ba34aSJunchao Zhang         } else if (k < cdend) { /* k in [cdstart, cdend) */
808076ba34aSJunchao Zhang           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
809076ba34aSJunchao Zhang           Cdj(Cdi(i)+(k-cdstart)) = Cj(Ci(i)+k) - cstart; /* Use local col ids in Cdj */
810076ba34aSJunchao Zhang         } else { /* k in [cdend, clen) */
811076ba34aSJunchao Zhang           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
812076ba34aSJunchao Zhang           Coj(Coi(i)+k-cdlen) = Cj(Ci(i)+k);
813076ba34aSJunchao Zhang         }
814076ba34aSJunchao Zhang       });
815076ba34aSJunchao Zhang     });
816076ba34aSJunchao Zhang 
817076ba34aSJunchao Zhang     Cdj_dual.modify_device();
818076ba34aSJunchao Zhang     Cda_dual.modify_device();
819076ba34aSJunchao Zhang     Coj_dual.modify_device();
820076ba34aSJunchao Zhang     Coa_dual.modify_device();
821076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
822076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m,n,Cd_nnz,Cdi_dual,Cdj_dual,Cda_dual);
823076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m,N,Co_nnz,Coi_dual,Coj_dual,Coa_dual);
824076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cdkok,&Cd);CHKERRQ(ierr);
825076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cokok,&Co);CHKERRQ(ierr);
826076ba34aSJunchao Zhang     ierr = MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C,Cd,Co);CHKERRQ(ierr); /* Coj will be converted to local ids within */
827076ba34aSJunchao Zhang   }
828076ba34aSJunchao Zhang   PetscFunctionReturn(0);
829076ba34aSJunchao Zhang }
830076ba34aSJunchao Zhang 
/* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.

  It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.

  Input Parameters:
.  C        - the SEQAIJKOKKOS matrix, using global col ids

  Output Parameters:
+  C        - the updated matrix, with its col ids and col size compacted
-  l2g      - the local to global mapping of C's cols, in a Kokkos view whose size is C's new col size

  Notes: the input matrix's col ids and col size will be changed.
*/
848076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C,MatColIdxKokkosView& l2g)
849076ba34aSJunchao Zhang {
850076ba34aSJunchao Zhang   PetscErrorCode         ierr;
851076ba34aSJunchao Zhang   Mat_SeqAIJKokkos       *ckok;
852076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
853076ba34aSJunchao Zhang   const PetscInt         *garray;
854076ba34aSJunchao Zhang   PetscInt               sz;
855076ba34aSJunchao Zhang 
856076ba34aSJunchao Zhang   PetscFunctionBegin;
857076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
858076ba34aSJunchao Zhang   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJ(C,&l2gmap);CHKERRQ(ierr);
859076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
860076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
861076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
862076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
863076ba34aSJunchao Zhang 
864076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
865076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingGetIndices(l2gmap,&garray);CHKERRQ(ierr);
866076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingGetSize(l2gmap,&sz);CHKERRQ(ierr);
8672c71b3e2SJacob Faibussowitsch   PetscCheckFalse(C->cmap->n != sz,PETSC_COMM_SELF,PETSC_ERR_PLIB,"matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n,sz);
868076ba34aSJunchao Zhang 
869076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray,sz);
870076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g",sz);
871076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g,tmp);
872076ba34aSJunchao Zhang 
873076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingRestoreIndices(l2gmap,&garray);CHKERRQ(ierr);
874076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingDestroy(&l2gmap);CHKERRQ(ierr);
875076ba34aSJunchao Zhang   PetscFunctionReturn(0);
876076ba34aSJunchao Zhang }
877076ba34aSJunchao Zhang 
878076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
879076ba34aSJunchao Zhang 
880076ba34aSJunchao Zhang   Input Parameters:
881076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
882076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
883076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
884076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
885076ba34aSJunchao Zhang 
886076ba34aSJunchao Zhang   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
887076ba34aSJunchao Zhang */
888076ba34aSJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product,Mat A,Mat B,MatMatStruct_AB *mm)
889076ba34aSJunchao Zhang {
890076ba34aSJunchao Zhang   PetscErrorCode              ierr;
891076ba34aSJunchao Zhang   Mat_MPIAIJ                  *a = static_cast<Mat_MPIAIJ*>(A->data);
892076ba34aSJunchao Zhang   Mat                         Ad = a->A,Ao = a->B; /* diag and offdiag of A */
893076ba34aSJunchao Zhang   IS                          glob = NULL;
894076ba34aSJunchao Zhang   const PetscInt              *garray;
895076ba34aSJunchao Zhang   PetscInt                    N = B->cmap->N,sz;
896076ba34aSJunchao Zhang   ConstMatColIdxKokkosView    l2g1; /* two temp maps mapping local col ids to global ones */
897076ba34aSJunchao Zhang   MatColIdxKokkosView         l2g2;
898076ba34aSJunchao Zhang   Mat                         C1,C2; /* intermediate matrices */
899076ba34aSJunchao Zhang 
900076ba34aSJunchao Zhang   PetscFunctionBegin;
901076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
902076ba34aSJunchao Zhang   ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&mm->B_local);CHKERRQ(ierr);
903076ba34aSJunchao Zhang   ierr = MatProductCreate(Ad,mm->B_local,NULL,&C1);CHKERRQ(ierr);
904076ba34aSJunchao Zhang   ierr = MatProductSetType(C1,MATPRODUCT_AB);CHKERRQ(ierr);
905076ba34aSJunchao Zhang   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
906076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
907076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
9082c71b3e2SJacob Faibussowitsch   PetscCheckFalse(!C1->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
909076ba34aSJunchao Zhang   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
910076ba34aSJunchao Zhang 
911076ba34aSJunchao Zhang   ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
912076ba34aSJunchao Zhang   ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
913076ba34aSJunchao Zhang   const auto& tmp  = ConstMatColIdxKokkosViewHost(garray,sz); /* wrap garray as a view */
914076ba34aSJunchao Zhang   l2g1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
915076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g1,mm->C1_global);
916076ba34aSJunchao Zhang 
917076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
918076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosBcast(mm->B_local,MAT_INITIAL_MATRIX,N,l2g1,a->Mvctx,mm->sf,
919076ba34aSJunchao Zhang                               mm->abuf,mm->rows,mm->rowoffset,mm->B_other);CHKERRQ(ierr);
920076ba34aSJunchao Zhang 
921076ba34aSJunchao Zhang   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
922076ba34aSJunchao Zhang   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other,l2g2);CHKERRQ(ierr);
923076ba34aSJunchao Zhang   ierr = MatProductCreate(Ao,mm->B_other,NULL,&C2);CHKERRQ(ierr);
924076ba34aSJunchao Zhang   ierr = MatProductSetType(C2,MATPRODUCT_AB);CHKERRQ(ierr);
925076ba34aSJunchao Zhang   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
926076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
927076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
9282c71b3e2SJacob Faibussowitsch   PetscCheckFalse(!C2->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
929076ba34aSJunchao Zhang   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
930076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2,N,l2g2,mm->C2_global);
931076ba34aSJunchao Zhang 
932076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
933076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
934076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
935076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
936076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
937076ba34aSJunchao Zhang 
938076ba34aSJunchao Zhang   mm->C1 = C1;
939076ba34aSJunchao Zhang   mm->C2 = C2;
940076ba34aSJunchao Zhang   ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
941076ba34aSJunchao Zhang   ierr = ISDestroy(&glob);CHKERRQ(ierr);
942076ba34aSJunchao Zhang   PetscFunctionReturn(0);
943076ba34aSJunchao Zhang }
944076ba34aSJunchao Zhang 
945076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
946076ba34aSJunchao Zhang 
947076ba34aSJunchao Zhang   Input Parameters:
948076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
949076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
950076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
951076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
952076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
953076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
954076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
955076ba34aSJunchao Zhang 
956076ba34aSJunchao Zhang   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
957076ba34aSJunchao Zhang */
958076ba34aSJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product,Mat A,Mat B,PetscBool localB,PetscInt N,const ConstMatColIdxKokkosView& l2g,MatMatStruct_AtB *mm)
959076ba34aSJunchao Zhang {
960076ba34aSJunchao Zhang   PetscErrorCode         ierr;
961076ba34aSJunchao Zhang   Mat_MPIAIJ             *a = static_cast<Mat_MPIAIJ*>(A->data);
962076ba34aSJunchao Zhang   Mat                    Ad = a->A,Ao = a->B; /* diag and offdiag of A */
963076ba34aSJunchao Zhang   Mat                    C1,C2; /* intermediate matrices */
964076ba34aSJunchao Zhang 
965076ba34aSJunchao Zhang   PetscFunctionBegin;
966076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
967076ba34aSJunchao Zhang   ierr = MatProductCreate(Ad,B,NULL,&C1);CHKERRQ(ierr);
968076ba34aSJunchao Zhang   ierr = MatProductSetType(C1,MATPRODUCT_AtB);CHKERRQ(ierr);
969076ba34aSJunchao Zhang   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
970076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
971076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
9722c71b3e2SJacob Faibussowitsch   PetscCheckFalse(!C1->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
973076ba34aSJunchao Zhang   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
974076ba34aSJunchao Zhang 
975076ba34aSJunchao Zhang   if (localB) {ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g,mm->C1_global);}
976076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos*>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
977076ba34aSJunchao Zhang 
978076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
979076ba34aSJunchao Zhang   ierr = MatProductCreate(Ao,B,NULL,&C2);CHKERRQ(ierr);
980076ba34aSJunchao Zhang   ierr = MatProductSetType(C2,MATPRODUCT_AtB);CHKERRQ(ierr);
981076ba34aSJunchao Zhang   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
982076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
983076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
9842c71b3e2SJacob Faibussowitsch   PetscCheckFalse(!C2->ops->productsymbolic,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
985076ba34aSJunchao Zhang   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
986076ba34aSJunchao Zhang 
987076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosReduce(C2,MAT_INITIAL_MATRIX,localB,N,l2g,a->Mvctx,mm->sf,mm->abuf,
988076ba34aSJunchao Zhang                                mm->srcrowoffset,mm->dstrowoffset,mm->C2_global);CHKERRQ(ierr);
989076ba34aSJunchao Zhang 
990076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
991076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
992076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
993076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
994076ba34aSJunchao Zhang   mm->C1 = C1;
995076ba34aSJunchao Zhang   mm->C2 = C2;
996076ba34aSJunchao Zhang   PetscFunctionReturn(0);
997076ba34aSJunchao Zhang }
998076ba34aSJunchao Zhang 
999076ba34aSJunchao Zhang PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1000076ba34aSJunchao Zhang {
1001076ba34aSJunchao Zhang   PetscErrorCode                ierr;
1002076ba34aSJunchao Zhang   Mat_Product                   *product = C->product;
1003076ba34aSJunchao Zhang   MatProductType                ptype;
1004076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos   *mmdata;
1005076ba34aSJunchao Zhang   MatMatStruct                  *mm = NULL;
1006076ba34aSJunchao Zhang   MatMatStruct_AB               *ab;
1007076ba34aSJunchao Zhang   MatMatStruct_AtB              *atb;
1008076ba34aSJunchao Zhang   Mat                           A,B,Ad,Ao,Bd,Bo;
1009076ba34aSJunchao Zhang   const MatScalarType           one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1010076ba34aSJunchao Zhang 
1011076ba34aSJunchao Zhang   PetscFunctionBegin;
1012076ba34aSJunchao Zhang   MatCheckProduct(C,1);
1013076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos*>(product->data);
1014076ba34aSJunchao Zhang   ptype  = product->type;
1015076ba34aSJunchao Zhang   A      = product->A;
1016076ba34aSJunchao Zhang   B      = product->B;
1017076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ*>(A->data)->A;
1018076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ*>(A->data)->B;
1019076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ*>(B->data)->A;
1020076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ*>(B->data)->B;
1021076ba34aSJunchao Zhang 
1022076ba34aSJunchao Zhang   if (mmdata->reusesym) { /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1023076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1024076ba34aSJunchao Zhang     ab  = mmdata->mmAB;
1025076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
1026076ba34aSJunchao Zhang     if (ab) {
1027076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1028076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1029076ba34aSJunchao Zhang     }
1030076ba34aSJunchao Zhang     if (atb) {
1031076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1032076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1033076ba34aSJunchao Zhang     }
1034076ba34aSJunchao Zhang     PetscFunctionReturn(0);
1035076ba34aSJunchao Zhang   }
1036076ba34aSJunchao Zhang 
1037076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1038076ba34aSJunchao Zhang     ab   = mmdata->mmAB;
1039076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
10402c71b3e2SJacob Faibussowitsch     PetscCheckFalse(!ab->C1->ops->productnumeric || !ab->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AB");
1041076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
10422c71b3e2SJacob Faibussowitsch     PetscCheckFalse(ab->C1->product->B != ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1043076ba34aSJunchao Zhang     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1044076ba34aSJunchao Zhang     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1045076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1046076ba34aSJunchao Zhang                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1047076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
10482c71b3e2SJacob Faibussowitsch     PetscCheckFalse(ab->C2->product->B != ab->B_other,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1049076ba34aSJunchao Zhang     if (ab->C1->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1050076ba34aSJunchao Zhang     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr);
1051076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1052076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1053076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(ab);
1054076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1055076ba34aSJunchao Zhang     atb  = mmdata->mmAtB;
10562c71b3e2SJacob Faibussowitsch     PetscCheckFalse(!atb->C1->ops->productnumeric || !atb->C2->ops->productnumeric,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AtB");
1057076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
1058076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&atb->B_local);CHKERRQ(ierr);
10592c71b3e2SJacob Faibussowitsch     PetscCheckFalse(atb->C1->product->B != atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1060076ba34aSJunchao Zhang     if (atb->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1061076ba34aSJunchao Zhang     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1062076ba34aSJunchao Zhang 
1063076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
10642c71b3e2SJacob Faibussowitsch     PetscCheckFalse(atb->C2->product->B != atb->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1065076ba34aSJunchao Zhang     if (atb->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1066076ba34aSJunchao Zhang     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1067076ba34aSJunchao Zhang     /* Form C2_global */
1068076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_TRUE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1069076ba34aSJunchao Zhang                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1070076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1071076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1072076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(atb);
1073076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1074076ba34aSJunchao Zhang     ab   = mmdata->mmAB;
1075076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1076076ba34aSJunchao Zhang 
1077076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
10782c71b3e2SJacob Faibussowitsch     PetscCheckFalse(ab->C1->product->B != ab->B_local,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1079076ba34aSJunchao Zhang     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1080076ba34aSJunchao Zhang     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1081076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1082076ba34aSJunchao Zhang                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1083076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
1084076ba34aSJunchao Zhang     if (ab->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1085076ba34aSJunchao Zhang     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr); /* C2 = Ao * B_other */
1086076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1087076ba34aSJunchao Zhang 
1088076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1089076ba34aSJunchao Zhang     atb  = mmdata->mmAtB;
10902c71b3e2SJacob Faibussowitsch     PetscCheckFalse(atb->C1->product->B != ab->C_petsc,PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1091076ba34aSJunchao Zhang     if (atb->C1->product->A != Bd) {ierr = MatProductReplaceMats(Bd,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1092076ba34aSJunchao Zhang     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1093076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
1094076ba34aSJunchao Zhang     if (atb->C2->product->A != Bo) {ierr = MatProductReplaceMats(Bo,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1095076ba34aSJunchao Zhang     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1096076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_FALSE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1097076ba34aSJunchao Zhang                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1098076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1099076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(atb);
1100076ba34aSJunchao Zhang   }
1101076ba34aSJunchao Zhang   /* Split C_global to form C */
1102076ba34aSJunchao Zhang   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_REUSE_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1103076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1104076ba34aSJunchao Zhang }
1105076ba34aSJunchao Zhang 
1106076ba34aSJunchao Zhang PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1107076ba34aSJunchao Zhang {
1108076ba34aSJunchao Zhang   PetscErrorCode              ierr;
1109076ba34aSJunchao Zhang   Mat                         A,B;
1110076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1111076ba34aSJunchao Zhang   MatProductType              ptype;
1112076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1113076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
1114076ba34aSJunchao Zhang   IS                          glob = NULL;
1115076ba34aSJunchao Zhang   const PetscInt              *garray;
1116076ba34aSJunchao Zhang   PetscInt                    m,n,M,N,sz;
1117076ba34aSJunchao Zhang   ConstMatColIdxKokkosView    l2g; /* map local col ids to global ones */
1118076ba34aSJunchao Zhang 
1119076ba34aSJunchao Zhang   PetscFunctionBegin;
1120076ba34aSJunchao Zhang   MatCheckProduct(C,1);
11212c71b3e2SJacob Faibussowitsch   PetscCheckFalse(product->data,PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Product data not empty");
1122076ba34aSJunchao Zhang   ptype = product->type;
1123076ba34aSJunchao Zhang   A     = product->A;
1124076ba34aSJunchao Zhang   B     = product->B;
1125076ba34aSJunchao Zhang 
1126076ba34aSJunchao Zhang   switch (ptype) {
1127076ba34aSJunchao Zhang     case MATPRODUCT_AB:   m = A->rmap->n; n = B->cmap->n; M = A->rmap->N; N = B->cmap->N; break;
1128076ba34aSJunchao Zhang     case MATPRODUCT_AtB:  m = A->cmap->n; n = B->cmap->n; M = A->cmap->N; N = B->cmap->N; break;
1129076ba34aSJunchao Zhang     case MATPRODUCT_PtAP: m = B->cmap->n; n = B->cmap->n; M = B->cmap->N; N = B->cmap->N; break; /* BtAB */
113098921bdaSJacob Faibussowitsch     default: SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[ptype]);
1131076ba34aSJunchao Zhang   }
1132076ba34aSJunchao Zhang 
1133076ba34aSJunchao Zhang   ierr = MatSetSizes(C,m,n,M,N);CHKERRQ(ierr);
1134076ba34aSJunchao Zhang   ierr = PetscLayoutSetUp(C->rmap);CHKERRQ(ierr);
1135076ba34aSJunchao Zhang   ierr = PetscLayoutSetUp(C->cmap);CHKERRQ(ierr);
1136076ba34aSJunchao Zhang   ierr = MatSetType(C,((PetscObject)A)->type_name);CHKERRQ(ierr);
1137076ba34aSJunchao Zhang 
1138076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1139076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1140076ba34aSJunchao Zhang 
1141076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1142076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
1143076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,mmdata->mmAB);CHKERRQ(ierr);
1144076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(mmdata->mmAB);
1145076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1146076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1147076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
1148076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&atb->B_local);CHKERRQ(ierr);
1149076ba34aSJunchao Zhang     ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
1150076ba34aSJunchao Zhang     ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
1151076ba34aSJunchao Zhang     l2g  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatColIdxKokkosViewHost(garray,sz));
1152076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,A,atb->B_local,PETSC_TRUE,N,l2g,atb);CHKERRQ(ierr);
1153076ba34aSJunchao Zhang     ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
1154076ba34aSJunchao Zhang     ierr = ISDestroy(&glob);CHKERRQ(ierr);
1155076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(atb);
1156076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1157076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB(); /* tmp=A*B */
1158076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1159076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1160076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
1161076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,ab);CHKERRQ(ierr);
1162076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1163076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,tmp,&ab->C_petsc);CHKERRQ(ierr);
1164076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,B,ab->C_petsc,PETSC_FALSE,N,l2g/*not used*/,atb);CHKERRQ(ierr);
1165076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(atb);
1166076ba34aSJunchao Zhang   }
1167076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
1168076ba34aSJunchao Zhang   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_INITIAL_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1169076ba34aSJunchao Zhang   C->product->data        = mmdata;
1170076ba34aSJunchao Zhang   C->product->destroy     = MatProductDataDestroy_MPIAIJKokkos;
1171076ba34aSJunchao Zhang   C->ops->productnumeric  = MatProductNumeric_MPIAIJKokkos;
1172076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1173076ba34aSJunchao Zhang }
1174076ba34aSJunchao Zhang 
1175076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1176076ba34aSJunchao Zhang {
1177076ba34aSJunchao Zhang   PetscErrorCode ierr;
1178076ba34aSJunchao Zhang   Mat_Product    *product = mat->product;
1179076ba34aSJunchao Zhang   PetscBool      match = PETSC_FALSE;
1180076ba34aSJunchao Zhang   PetscBool      usecpu = PETSC_FALSE;
1181076ba34aSJunchao Zhang 
1182076ba34aSJunchao Zhang   PetscFunctionBegin;
1183076ba34aSJunchao Zhang   MatCheckProduct(mat,1);
1184076ba34aSJunchao Zhang   if (!product->A->boundtocpu && !product->B->boundtocpu) {
1185076ba34aSJunchao Zhang     ierr = PetscObjectTypeCompare((PetscObject)product->B,((PetscObject)product->A)->type_name,&match);CHKERRQ(ierr);
1186076ba34aSJunchao Zhang   }
1187076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1188076ba34aSJunchao Zhang     switch (product->type) {
1189076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1190076ba34aSJunchao Zhang       if (product->api_user) {
1191076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatMatMult","Mat");CHKERRQ(ierr);
1192076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matmatmult_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1193076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1194076ba34aSJunchao Zhang       } else {
1195076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AB","Mat");CHKERRQ(ierr);
11963e662e0bSHong Zhang         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1197076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1198076ba34aSJunchao Zhang       }
1199076ba34aSJunchao Zhang       break;
1200076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1201076ba34aSJunchao Zhang       if (product->api_user) {
1202076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatTransposeMatMult","Mat");CHKERRQ(ierr);
1203076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-mattransposematmult_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1204076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1205076ba34aSJunchao Zhang       } else {
1206076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AtB","Mat");CHKERRQ(ierr);
12073e662e0bSHong Zhang         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1208076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1209076ba34aSJunchao Zhang       }
1210076ba34aSJunchao Zhang       break;
1211076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1212076ba34aSJunchao Zhang       if (product->api_user) {
1213076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatPtAP","Mat");CHKERRQ(ierr);
1214076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1215076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1216076ba34aSJunchao Zhang       } else {
1217076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_PtAP","Mat");CHKERRQ(ierr);
12183e662e0bSHong Zhang         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1219076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1220076ba34aSJunchao Zhang       }
1221076ba34aSJunchao Zhang       break;
1222076ba34aSJunchao Zhang     default:
1223076ba34aSJunchao Zhang       break;
1224076ba34aSJunchao Zhang     }
1225076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1226076ba34aSJunchao Zhang   }
1227076ba34aSJunchao Zhang   if (match) {
1228076ba34aSJunchao Zhang     switch (product->type) {
1229076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1230076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1231076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1232076ba34aSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1233076ba34aSJunchao Zhang       break;
1234076ba34aSJunchao Zhang     default:
1235076ba34aSJunchao Zhang       break;
1236076ba34aSJunchao Zhang     }
1237076ba34aSJunchao Zhang   }
1238076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
1239076ba34aSJunchao Zhang   if (!mat->ops->productsymbolic) {
1240076ba34aSJunchao Zhang     ierr = MatProductSetFromOptions_MPIAIJ(mat);CHKERRQ(ierr);
1241076ba34aSJunchao Zhang   }
1242076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1243076ba34aSJunchao Zhang }
1244076ba34aSJunchao Zhang 
124582a78a4eSJed Brown static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, const PetscInt coo_i[], const PetscInt coo_j[])
124642550becSJunchao Zhang {
124742550becSJunchao Zhang   PetscErrorCode            ierr;
1248*394ed5ebSJunchao Zhang   Mat                       newmat;
1249*394ed5ebSJunchao Zhang   Mat_MPIAIJ                *mpiaij = (Mat_MPIAIJ*)mat->data;
125042550becSJunchao Zhang 
125142550becSJunchao Zhang   PetscFunctionBegin;
1252*394ed5ebSJunchao Zhang   ierr = MatCreate(PetscObjectComm((PetscObject)mat),&newmat);CHKERRQ(ierr);
1253*394ed5ebSJunchao Zhang   ierr = MatSetSizes(newmat,mat->rmap->n,mat->cmap->n,mat->rmap->N,mat->cmap->N);CHKERRQ(ierr);
1254*394ed5ebSJunchao Zhang   ierr = MatSetType(newmat,MATMPIAIJ);CHKERRQ(ierr);
1255*394ed5ebSJunchao Zhang   ierr = MatSetOption(newmat,MAT_IGNORE_OFF_PROC_ENTRIES,mpiaij->donotstash);CHKERRQ(ierr);  /* Inherit the two options that we respect from mat */
1256*394ed5ebSJunchao Zhang   ierr = MatSetOption(newmat,MAT_NO_OFF_PROC_ENTRIES,mat->nooffprocentries);CHKERRQ(ierr);
1257*394ed5ebSJunchao Zhang   ierr = MatSetPreallocationCOO_MPIAIJ(newmat,coo_n,coo_i,coo_j);CHKERRQ(ierr);
125842550becSJunchao Zhang   ierr = MatConvert(newmat,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&newmat);CHKERRQ(ierr);
1259*394ed5ebSJunchao Zhang   ierr = MatHeaderMerge(mat,&newmat);CHKERRQ(ierr); /* Not MatHeaderReplace() since we want to keep some mat's info */
126042550becSJunchao Zhang   ierr = MatZeroEntries(mat);CHKERRQ(ierr); /* Zero matrix on device */
1261*394ed5ebSJunchao Zhang   mpiaij = static_cast<Mat_MPIAIJ*>(mat->data); /* mat->data was changed in MatHeaderReplace() */
1262*394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
126342550becSJunchao Zhang   PetscFunctionReturn(0);
126442550becSJunchao Zhang }
126542550becSJunchao Zhang 
126642550becSJunchao Zhang static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat,const PetscScalar v[],InsertMode imode)
126742550becSJunchao Zhang {
126842550becSJunchao Zhang   PetscErrorCode                 ierr;
1269*394ed5ebSJunchao Zhang   Mat_MPIAIJ                     *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
127042550becSJunchao Zhang   Mat_MPIAIJKokkos               *mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
127142550becSJunchao Zhang   Mat                            A = mpiaij->A,B = mpiaij->B;
1272*394ed5ebSJunchao Zhang   PetscCount                     Annz1 = mpiaij->Annz1,Annz2 = mpiaij->Annz2,Bnnz1 = mpiaij->Bnnz1,Bnnz2 = mpiaij->Bnnz2;
127342550becSJunchao Zhang   MatScalarKokkosView            Aa,Ba;
1274*394ed5ebSJunchao Zhang   MatScalarKokkosView            v1;
127542550becSJunchao Zhang   MatScalarKokkosView&           vsend = mpikok->sendbuf_d;
127642550becSJunchao Zhang   const MatScalarKokkosView&     v2 = mpikok->recvbuf_d;
1277*394ed5ebSJunchao Zhang   const PetscCountKokkosView&    Ajmap1 = mpikok->Ajmap1_d,Ajmap2 = mpikok->Ajmap2_d,Aimap1 = mpikok->Aimap1_d,Aimap2 = mpikok->Aimap2_d;
1278*394ed5ebSJunchao Zhang   const PetscCountKokkosView&    Bjmap1 = mpikok->Bjmap1_d,Bjmap2 = mpikok->Bjmap2_d,Bimap1 = mpikok->Bimap1_d,Bimap2 = mpikok->Bimap2_d;
1279*394ed5ebSJunchao Zhang   const PetscCountKokkosView&    Aperm1 = mpikok->Aperm1_d,Aperm2 = mpikok->Aperm2_d,Bperm1 = mpikok->Bperm1_d,Bperm2 = mpikok->Bperm2_d;
1280*394ed5ebSJunchao Zhang   const PetscCountKokkosView&    Cperm1 = mpikok->Cperm1_d;
128142550becSJunchao Zhang   PetscMemType                   memtype;
128242550becSJunchao Zhang 
128342550becSJunchao Zhang   PetscFunctionBegin;
1284*394ed5ebSJunchao Zhang   PetscAssert(mat->assembled,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Expected matrix to be already assembled in MatSetPreallocationCOO()");
1285*394ed5ebSJunchao Zhang   ierr = PetscGetMemType(v,&memtype);CHKERRQ(ierr); /* Return PETSC_MEMTYPE_HOST when v is NULL */
128642550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */
1287*394ed5ebSJunchao Zhang     if (v) {
1288*394ed5ebSJunchao Zhang       v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),MatScalarKokkosViewHost((PetscScalar*)v,mpiaij->coo_n));
128942550becSJunchao Zhang     } else {
1290*394ed5ebSJunchao Zhang       v1 = MatScalarKokkosView("v1",mpiaij->coo_n);
1291*394ed5ebSJunchao Zhang       Kokkos::deep_copy(v1,0.0);
1292*394ed5ebSJunchao Zhang     }
1293*394ed5ebSJunchao Zhang   } else {
1294*394ed5ebSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar*)v,mpiaij->coo_n); /* Directly use v[]'s memory */
129542550becSJunchao Zhang   }
129642550becSJunchao Zhang 
129742550becSJunchao Zhang   if (imode == INSERT_VALUES) {
1298*394ed5ebSJunchao Zhang     ierr = MatSeqAIJGetKokkosViewWrite(A,&Aa);CHKERRQ(ierr); /* write matrix values */
1299*394ed5ebSJunchao Zhang     ierr = MatSeqAIJGetKokkosViewWrite(B,&Ba);CHKERRQ(ierr);
130042550becSJunchao Zhang     Kokkos::deep_copy(Aa,0.0); /* Zero matrix values since INSERT_VALUES still requires summing replicated values in v[] */
130142550becSJunchao Zhang     Kokkos::deep_copy(Ba,0.0);
1302*394ed5ebSJunchao Zhang   } else {
1303*394ed5ebSJunchao Zhang     ierr = MatSeqAIJGetKokkosView(A,&Aa);CHKERRQ(ierr); /* read & write matrix values */
1304*394ed5ebSJunchao Zhang     ierr = MatSeqAIJGetKokkosView(B,&Ba);CHKERRQ(ierr);
130542550becSJunchao Zhang   }
130642550becSJunchao Zhang 
130742550becSJunchao Zhang   /* Pack entries to be sent to remote */
1308*394ed5ebSJunchao Zhang   Kokkos::parallel_for(vsend.extent(0),KOKKOS_LAMBDA(const PetscCount i) {vsend(i) = v1(Cperm1(i));});
130942550becSJunchao Zhang 
131042550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
1311*394ed5ebSJunchao Zhang   ierr = PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf,MPIU_SCALAR,PETSC_MEMTYPE_KOKKOS,vsend.data(),PETSC_MEMTYPE_KOKKOS,v2.data(),MPI_REPLACE);CHKERRQ(ierr);
131242550becSJunchao Zhang   /* Add local entries to A and B */
1313*394ed5ebSJunchao Zhang   Kokkos::parallel_for(Annz1,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Ajmap1(i); k<Ajmap1(i+1); k++) Aa(Aimap1(i)) += v1(Aperm1(k));});
1314*394ed5ebSJunchao Zhang   Kokkos::parallel_for(Bnnz1,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Bjmap1(i); k<Bjmap1(i+1); k++) Ba(Bimap1(i)) += v1(Bperm1(k));});
1315*394ed5ebSJunchao Zhang   ierr = PetscSFReduceEnd(mpiaij->coo_sf,MPIU_SCALAR,vsend.data(),v2.data(),MPI_REPLACE);CHKERRQ(ierr);
131642550becSJunchao Zhang 
131742550becSJunchao Zhang   /* Add received remote entries to A and B */
1318*394ed5ebSJunchao Zhang   Kokkos::parallel_for(Annz2,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Ajmap2(i); k<Ajmap2(i+1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));});
1319*394ed5ebSJunchao Zhang   Kokkos::parallel_for(Bnnz2,KOKKOS_LAMBDA(const PetscCount i) {for (PetscCount k=Bjmap2(i); k<Bjmap2(i+1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));});
132042550becSJunchao Zhang 
1321*394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
1322*394ed5ebSJunchao Zhang     ierr = MatSeqAIJRestoreKokkosViewWrite(A,&Aa);CHKERRQ(ierr); /* Increase A & B's state etc. */
1323*394ed5ebSJunchao Zhang     ierr = MatSeqAIJRestoreKokkosViewWrite(B,&Ba);CHKERRQ(ierr);
1324*394ed5ebSJunchao Zhang   } else {
132542550becSJunchao Zhang     ierr = MatSeqAIJRestoreKokkosView(A,&Aa);CHKERRQ(ierr);
132642550becSJunchao Zhang     ierr = MatSeqAIJRestoreKokkosView(B,&Ba);CHKERRQ(ierr);
1327*394ed5ebSJunchao Zhang   }
132842550becSJunchao Zhang   PetscFunctionReturn(0);
132942550becSJunchao Zhang }
133042550becSJunchao Zhang 
1331076ba34aSJunchao Zhang PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1332076ba34aSJunchao Zhang {
1333076ba34aSJunchao Zhang   PetscErrorCode     ierr;
133442550becSJunchao Zhang   Mat_MPIAIJ         *mpiaij = (Mat_MPIAIJ*)A->data;
1335076ba34aSJunchao Zhang 
1336076ba34aSJunchao Zhang   PetscFunctionBegin;
1337076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",NULL);CHKERRQ(ierr);
1338076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",NULL);CHKERRQ(ierr);
133942550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatSetPreallocationCOO_C",   NULL);CHKERRQ(ierr);
134042550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatSetValuesCOO_C",          NULL);CHKERRQ(ierr);
134142550becSJunchao Zhang   delete (Mat_MPIAIJKokkos*)mpiaij->spptr;
1342076ba34aSJunchao Zhang   ierr = MatDestroy_MPIAIJ(A);CHKERRQ(ierr);
1343076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1344076ba34aSJunchao Zhang }
1345076ba34aSJunchao Zhang 
13468c3ff71bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat* newmat)
13478c3ff71bSJunchao Zhang {
13488c3ff71bSJunchao Zhang   PetscErrorCode     ierr;
13498c3ff71bSJunchao Zhang   Mat                B;
1350076ba34aSJunchao Zhang   Mat_MPIAIJ         *a;
13518c3ff71bSJunchao Zhang 
13528c3ff71bSJunchao Zhang   PetscFunctionBegin;
13538c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
13548c3ff71bSJunchao Zhang     ierr = MatDuplicate(A,MAT_COPY_VALUES,newmat);CHKERRQ(ierr);
13558c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
13568c3ff71bSJunchao Zhang     ierr = MatCopy(A,*newmat,SAME_NONZERO_PATTERN);CHKERRQ(ierr);
13578c3ff71bSJunchao Zhang   }
13588c3ff71bSJunchao Zhang   B = *newmat;
13598c3ff71bSJunchao Zhang 
13606f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
13618c3ff71bSJunchao Zhang   ierr = PetscFree(B->defaultvectype);CHKERRQ(ierr);
13628c3ff71bSJunchao Zhang   ierr = PetscStrallocpy(VECKOKKOS,&B->defaultvectype);CHKERRQ(ierr);
13633d0639e7SStefano Zampini   ierr = PetscObjectChangeTypeName((PetscObject)B,MATMPIAIJKOKKOS);CHKERRQ(ierr);
13648c3ff71bSJunchao Zhang 
1365076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ*>(A->data);
1366076ba34aSJunchao Zhang   if (a->A) {ierr = MatSetType(a->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1367076ba34aSJunchao Zhang   if (a->B) {ierr = MatSetType(a->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1368076ba34aSJunchao Zhang   if (a->lvec) {ierr = VecSetType(a->lvec,VECSEQKOKKOS);CHKERRQ(ierr);}
1369076ba34aSJunchao Zhang 
13708c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
13718c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
13728c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
13738c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1374076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1375076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
13768c3ff71bSJunchao Zhang 
13773d0639e7SStefano Zampini   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJSetPreallocation_C",MatMPIAIJSetPreallocation_MPIAIJKokkos);CHKERRQ(ierr);
137842550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJGetLocalMatMerge_C",MatMPIAIJGetLocalMatMerge_MPIAIJKokkos);CHKERRQ(ierr);
137942550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)B,"MatSetPreallocationCOO_C",   MatSetPreallocationCOO_MPIAIJKokkos);CHKERRQ(ierr);
138042550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)B,"MatSetValuesCOO_C",          MatSetValuesCOO_MPIAIJKokkos);CHKERRQ(ierr);
13818c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13828c3ff71bSJunchao Zhang }
13838c3ff71bSJunchao Zhang 
13848c3ff71bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
13858c3ff71bSJunchao Zhang {
13868c3ff71bSJunchao Zhang   PetscErrorCode ierr;
13878c3ff71bSJunchao Zhang 
13888c3ff71bSJunchao Zhang   PetscFunctionBegin;
13898c3ff71bSJunchao Zhang   ierr = PetscKokkosInitializeCheck();CHKERRQ(ierr);
13908c3ff71bSJunchao Zhang   ierr = MatCreate_MPIAIJ(A);CHKERRQ(ierr);
13918c3ff71bSJunchao Zhang   ierr = MatConvert_MPIAIJ_MPIAIJKokkos(A,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&A);CHKERRQ(ierr);
13928c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13938c3ff71bSJunchao Zhang }
13948c3ff71bSJunchao Zhang 
13958c3ff71bSJunchao Zhang /*@C
13968c3ff71bSJunchao Zhang    MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
13978c3ff71bSJunchao Zhang    (the default parallel PETSc format).  This matrix will ultimately pushed down
13988c3ff71bSJunchao Zhang    to Kokkos for calculations. For good matrix
13998c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
14008c3ff71bSJunchao Zhang    the parameter nz (or the array nnz).  By setting these parameters accurately,
14018c3ff71bSJunchao Zhang    performance during matrix assembly can be increased by more than a factor of 50.
14028c3ff71bSJunchao Zhang 
14038c3ff71bSJunchao Zhang    Collective
14048c3ff71bSJunchao Zhang 
14058c3ff71bSJunchao Zhang    Input Parameters:
14068c3ff71bSJunchao Zhang +  comm - MPI communicator, set to PETSC_COMM_SELF
14078c3ff71bSJunchao Zhang .  m - number of rows
14088c3ff71bSJunchao Zhang .  n - number of columns
14098c3ff71bSJunchao Zhang .  nz - number of nonzeros per row (same for all rows)
14108c3ff71bSJunchao Zhang -  nnz - array containing the number of nonzeros in the various rows
14118c3ff71bSJunchao Zhang          (possibly different for each row) or NULL
14128c3ff71bSJunchao Zhang 
14138c3ff71bSJunchao Zhang    Output Parameter:
14148c3ff71bSJunchao Zhang .  A - the matrix
14158c3ff71bSJunchao Zhang 
14168c3ff71bSJunchao Zhang    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
14178c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
14188c3ff71bSJunchao Zhang    [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]
14198c3ff71bSJunchao Zhang 
14208c3ff71bSJunchao Zhang    Notes:
14218c3ff71bSJunchao Zhang    If nnz is given then nz is ignored
14228c3ff71bSJunchao Zhang 
14238c3ff71bSJunchao Zhang    The AIJ format (also called the Yale sparse matrix format or
14248c3ff71bSJunchao Zhang    compressed row storage), is fully compatible with standard Fortran 77
14258c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
14268c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
14278c3ff71bSJunchao Zhang 
14288c3ff71bSJunchao Zhang    Specify the preallocated storage with either nz or nnz (not both).
14298c3ff71bSJunchao Zhang    Set nz=PETSC_DEFAULT and nnz=NULL for PETSc to control dynamic memory
14308c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
14318c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
14328c3ff71bSJunchao Zhang 
14338c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
14348c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
14358c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
14368c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
14378c3ff71bSJunchao Zhang 
14388c3ff71bSJunchao Zhang    Level: intermediate
14398c3ff71bSJunchao Zhang 
14408c3ff71bSJunchao Zhang .seealso: MatCreate(), MatCreateAIJ(), MatSetValues(), MatSeqAIJSetColumnIndices(), MatCreateSeqAIJWithArrays(), MatCreateAIJ(), MATMPIAIJKOKKOS, MATAIJKokkos
14418c3ff71bSJunchao Zhang @*/
14428c3ff71bSJunchao Zhang PetscErrorCode  MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat *A)
14438c3ff71bSJunchao Zhang {
14448c3ff71bSJunchao Zhang   PetscErrorCode ierr;
14458c3ff71bSJunchao Zhang   PetscMPIInt    size;
14468c3ff71bSJunchao Zhang 
14478c3ff71bSJunchao Zhang   PetscFunctionBegin;
14488c3ff71bSJunchao Zhang   ierr = MatCreate(comm,A);CHKERRQ(ierr);
14498c3ff71bSJunchao Zhang   ierr = MatSetSizes(*A,m,n,M,N);CHKERRQ(ierr);
1450ffc4695bSBarry Smith   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
14518c3ff71bSJunchao Zhang   if (size > 1) {
14528c3ff71bSJunchao Zhang     ierr = MatSetType(*A,MATMPIAIJKOKKOS);CHKERRQ(ierr);
14538c3ff71bSJunchao Zhang     ierr = MatMPIAIJSetPreallocation(*A,d_nz,d_nnz,o_nz,o_nnz);CHKERRQ(ierr);
14548c3ff71bSJunchao Zhang   } else {
14558c3ff71bSJunchao Zhang     ierr = MatSetType(*A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
14568c3ff71bSJunchao Zhang     ierr = MatSeqAIJSetPreallocation(*A,d_nz,d_nnz);CHKERRQ(ierr);
14578c3ff71bSJunchao Zhang   }
14588c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
14598c3ff71bSJunchao Zhang }
14608c3ff71bSJunchao Zhang 
1461a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1462042217e8SBarry Smith PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1463a587d139SMark {
1464a587d139SMark   PetscMPIInt                size,rank;
1465a587d139SMark   MPI_Comm                   comm;
1466a587d139SMark   PetscErrorCode             ierr;
1467042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat=NULL;
1468a587d139SMark 
1469a587d139SMark   PetscFunctionBegin;
1470a587d139SMark   ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRQ(ierr);
147155b25c41SPierre Jolivet   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
147255b25c41SPierre Jolivet   ierr = MPI_Comm_rank(comm,&rank);CHKERRMPI(ierr);
1473a587d139SMark   if (size == 1) {
1474a587d139SMark     ierr   = MatSeqAIJKokkosGetDeviceMat(A,&d_mat);CHKERRQ(ierr);
1475fc76dfabSMark Adams     ierr   = MatSeqAIJKokkosModifyDevice(A);CHKERRQ(ierr); /* Since we are going to modify matrix values on device */
1476a587d139SMark   } else {
1477a587d139SMark     Mat_MPIAIJ  *aij = (Mat_MPIAIJ*)A->data;
1478a587d139SMark     ierr   = MatSeqAIJKokkosGetDeviceMat(aij->A,&d_mat);CHKERRQ(ierr);
1479fc76dfabSMark Adams     ierr   = MatSeqAIJKokkosModifyDevice(aij->A);CHKERRQ(ierr);
1480fc76dfabSMark Adams     ierr   = MatSeqAIJKokkosModifyDevice(aij->B);CHKERRQ(ierr);
14812c71b3e2SJacob Faibussowitsch     PetscCheck(A->nooffprocentries || aij->donotstash,PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1482a587d139SMark   }
1483a587d139SMark   // act like MatSetValues because not called on host
1484a587d139SMark   if (A->assembled) {
1485a587d139SMark     if (A->was_assembled) {
1486a587d139SMark       ierr = PetscInfo(A,"Assemble more than once already\n");CHKERRQ(ierr);
1487a587d139SMark     }
1488a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1489a587d139SMark   } else {
14907d3de750SJacob Faibussowitsch     ierr = PetscInfo(A,"Warning !assemble ??? assembled=%" PetscInt_FMT "\n",A->assembled);CHKERRQ(ierr);
1491a587d139SMark   }
1492a587d139SMark   if (!d_mat) {
1493042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1494a587d139SMark     Mat_SeqAIJKokkos      *aijkokA;
1495a587d139SMark     Mat_SeqAIJ            *jaca;
1496a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1497a587d139SMark     Mat                   Amat;
1498042217e8SBarry Smith     PetscInt              *colmap;
1499042217e8SBarry Smith 
1500042217e8SBarry Smith     /* create and copy h_mat */
150149b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
1502a587d139SMark     ierr = PetscInfo(A,"Create device matrix in Kokkos\n");CHKERRQ(ierr);
1503a587d139SMark     if (size == 1) {
1504a587d139SMark       Amat = A;
1505a587d139SMark       jaca = (Mat_SeqAIJ*)A->data;
1506a587d139SMark       h_mat.rstart = 0; h_mat.rend = A->rmap->n;
1507a587d139SMark       h_mat.cstart = 0; h_mat.cend = A->cmap->n;
1508a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1509a587d139SMark       h_mat.offdiag.a = NULL;
1510a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1511a587d139SMark     } else {
1512a587d139SMark       Mat_MPIAIJ       *aij = (Mat_MPIAIJ*)A->data;
1513a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ*)aij->B->data;
1514a587d139SMark       PetscInt         ii;
1515a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1516042217e8SBarry Smith 
1517a587d139SMark       Amat = aij->A;
1518a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos*>(aij->A->spptr);
1519a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos*>(aij->B->spptr);
1520a587d139SMark       jaca = (Mat_SeqAIJ*)aij->A->data;
15212c71b3e2SJacob Faibussowitsch       PetscCheckFalse(aij->B->cmap->n && !aij->garray,comm,PETSC_ERR_PLIB,"MPIAIJ Matrix was assembled but is missing garray");
15222c71b3e2SJacob Faibussowitsch       PetscCheckFalse(aij->B->rmap->n != aij->A->rmap->n,comm,PETSC_ERR_SUP,"Only support aij->B->rmap->n == aij->A->rmap->n");
1523a587d139SMark       aij->donotstash = PETSC_TRUE;
1524a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1525a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1526042217e8SBarry Smith       ierr = PetscCalloc1(A->cmap->N,&colmap);CHKERRQ(ierr);
1527042217e8SBarry Smith       ierr = PetscLogObjectMemory((PetscObject)A,(A->cmap->N)*sizeof(PetscInt));CHKERRQ(ierr);
1528042217e8SBarry Smith       for (ii=0; ii<aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii+1;
1529a587d139SMark       // allocate B copy data
1530a587d139SMark       h_mat.rstart = A->rmap->rstart; h_mat.rend = A->rmap->rend;
1531a587d139SMark       h_mat.cstart = A->cmap->rstart; h_mat.cend = A->cmap->rend;
1532a587d139SMark       nnz = jacb->i[n];
1533a587d139SMark       if (jacb->compressedrow.use) {
1534a587d139SMark         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_i_k (jacb->i,n+1);
1535300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_i_k));
1536300d22a6SJunchao Zhang         Kokkos::deep_copy (aijkokB->i_uncompressed_d, h_i_k);
1537300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1538a587d139SMark       } else {
153999551766SMark Adams          h_mat.offdiag.i = aijkokB->i_device_data();
1540a587d139SMark       }
154199551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1542076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1543a587d139SMark       {
1544042217e8SBarry Smith         Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_colmap_k (colmap,A->cmap->N);
1545300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_colmap_k));
1546300d22a6SJunchao Zhang         Kokkos::deep_copy (aijkokB->colmap_d, h_colmap_k);
1547300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
1548042217e8SBarry Smith         ierr = PetscFree(colmap);CHKERRQ(ierr);
1549a587d139SMark       }
1550a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1551a587d139SMark       h_mat.offdiag.n = n;
1552a587d139SMark     }
1553a587d139SMark     // allocate A copy data
1554a587d139SMark     nnz = jaca->i[n];
1555a587d139SMark     h_mat.diag.n = n;
1556a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
155755b25c41SPierre Jolivet     ierr = MPI_Comm_rank(comm,&h_mat.rank);CHKERRMPI(ierr);
15582c71b3e2SJacob Faibussowitsch     PetscCheckFalse(jaca->compressedrow.use,PETSC_COMM_SELF,PETSC_ERR_PLIB,"A does not suppport compressed row (todo)");
1559042217e8SBarry Smith     else {
156099551766SMark Adams       h_mat.diag.i = aijkokA->i_device_data();
1561a587d139SMark     }
156299551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1563076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1564a587d139SMark     // copy pointers and metdata to device
1565a587d139SMark     ierr = MatSeqAIJKokkosSetDeviceMat(Amat,&h_mat);CHKERRQ(ierr);
1566a587d139SMark     ierr = MatSeqAIJKokkosGetDeviceMat(Amat,&d_mat);CHKERRQ(ierr);
15677d3de750SJacob Faibussowitsch     ierr = PetscInfo(A,"Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n",h_mat.diag.n, nnz);CHKERRQ(ierr);
1568a587d139SMark   }
1569a587d139SMark   *B = d_mat; // return it, set it in Mat, and set it up
1570a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1571a587d139SMark   PetscFunctionReturn(0);
1572a587d139SMark }
1573076ba34aSJunchao Zhang 
1574076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1575076ba34aSJunchao Zhang {
1576076ba34aSJunchao Zhang   Mat_SeqAIJKokkos  *aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1577076ba34aSJunchao Zhang 
1578076ba34aSJunchao Zhang   PetscFunctionBegin;
1579076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1580076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1581076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1582076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
1583076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1584076ba34aSJunchao Zhang }
1585076ba34aSJunchao Zhang 
1586076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1587076ba34aSJunchao Zhang {
1588076ba34aSJunchao Zhang   PetscErrorCode    ierr;
1589076ba34aSJunchao Zhang   PetscMPIInt       size;
1590076ba34aSJunchao Zhang   Mat               Ad,Ao;
1591076ba34aSJunchao Zhang   const char        *amask,*bmask;
1592076ba34aSJunchao Zhang 
1593076ba34aSJunchao Zhang   PetscFunctionBegin;
1594076ba34aSJunchao Zhang   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRMPI(ierr);
1595076ba34aSJunchao Zhang 
1596076ba34aSJunchao Zhang   if (size == 1) {
1597076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(A,&amask);CHKERRQ(ierr);
1598076ba34aSJunchao Zhang     ierr = PetscPrintf(PETSC_COMM_SELF,"%s\n",amask);CHKERRQ(ierr);
1599076ba34aSJunchao Zhang   } else {
1600076ba34aSJunchao Zhang     Ad  = ((Mat_MPIAIJ*)A->data)->A;
1601076ba34aSJunchao Zhang     Ao  = ((Mat_MPIAIJ*)A->data)->B;
1602076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(Ad,&amask);CHKERRQ(ierr);
1603076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(Ao,&bmask);CHKERRQ(ierr);
1604076ba34aSJunchao Zhang     ierr = PetscPrintf(PETSC_COMM_SELF,"Diag : Off-diag = %s : %s\n",amask,bmask);CHKERRQ(ierr);
1605076ba34aSJunchao Zhang   }
1606076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1607076ba34aSJunchao Zhang }
1608