xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 3e662e0b84908e4f96ebbc649bff52cb1d2f6c56)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscsf.h>
342550becSJunchao Zhang #include <petsc/private/sfimpl.h>
48c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
542550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
642550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
811d22bbfSJunchao Zhang 
98c3ff71bSJunchao Zhang PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A,MatAssemblyType mode)
108c3ff71bSJunchao Zhang {
118c3ff71bSJunchao Zhang   PetscErrorCode   ierr;
128c3ff71bSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ*)A->data;
13a587d139SMark   Mat_SeqAIJKokkos *aijkok = mpiaij->A->spptr ? static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr) : NULL;
148c3ff71bSJunchao Zhang 
158c3ff71bSJunchao Zhang   PetscFunctionBegin;
168c3ff71bSJunchao Zhang   ierr = MatAssemblyEnd_MPIAIJ(A,mode);CHKERRQ(ierr);
17a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
18a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
19a587d139SMark   }
20a587d139SMark 
218c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
228c3ff71bSJunchao Zhang }
238c3ff71bSJunchao Zhang 
248c3ff71bSJunchao Zhang PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[])
258c3ff71bSJunchao Zhang {
268c3ff71bSJunchao Zhang   PetscErrorCode ierr;
278c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
288c3ff71bSJunchao Zhang 
298c3ff71bSJunchao Zhang   PetscFunctionBegin;
308c3ff71bSJunchao Zhang   ierr = PetscLayoutSetUp(mat->rmap);CHKERRQ(ierr);
318c3ff71bSJunchao Zhang   ierr = PetscLayoutSetUp(mat->cmap);CHKERRQ(ierr);
326a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
338c3ff71bSJunchao Zhang   if (d_nnz) {
346a29ce69SStefano Zampini     PetscInt i;
358c3ff71bSJunchao Zhang     for (i=0; i<mat->rmap->n; i++) {
36c0aa6a63SJacob Faibussowitsch       if (d_nnz[i] < 0) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,d_nnz[i]);
378c3ff71bSJunchao Zhang     }
388c3ff71bSJunchao Zhang   }
398c3ff71bSJunchao Zhang   if (o_nnz) {
406a29ce69SStefano Zampini     PetscInt i;
418c3ff71bSJunchao Zhang     for (i=0; i<mat->rmap->n; i++) {
42c0aa6a63SJacob Faibussowitsch       if (o_nnz[i] < 0) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT,i,o_nnz[i]);
438c3ff71bSJunchao Zhang     }
448c3ff71bSJunchao Zhang   }
456a29ce69SStefano Zampini #endif
466a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
476a29ce69SStefano Zampini   ierr = PetscTableDestroy(&mpiaij->colmap);CHKERRQ(ierr);
486a29ce69SStefano Zampini #else
496a29ce69SStefano Zampini   ierr = PetscFree(mpiaij->colmap);CHKERRQ(ierr);
506a29ce69SStefano Zampini #endif
516a29ce69SStefano Zampini   ierr = PetscFree(mpiaij->garray);CHKERRQ(ierr);
526a29ce69SStefano Zampini   ierr = VecDestroy(&mpiaij->lvec);CHKERRQ(ierr);
536a29ce69SStefano Zampini   ierr = VecScatterDestroy(&mpiaij->Mvctx);CHKERRQ(ierr);
546a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
556a29ce69SStefano Zampini   ierr = MatDestroy(&mpiaij->B);CHKERRQ(ierr);
566a29ce69SStefano Zampini 
576a29ce69SStefano Zampini   if (!mpiaij->A) {
588c3ff71bSJunchao Zhang     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->A);CHKERRQ(ierr);
598c3ff71bSJunchao Zhang     ierr = MatSetSizes(mpiaij->A,mat->rmap->n,mat->cmap->n,mat->rmap->n,mat->cmap->n);CHKERRQ(ierr);
608c3ff71bSJunchao Zhang     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->A);CHKERRQ(ierr);
616a29ce69SStefano Zampini   }
626a29ce69SStefano Zampini   if (!mpiaij->B) {
636a29ce69SStefano Zampini     PetscMPIInt size;
6455b25c41SPierre Jolivet     ierr = MPI_Comm_size(PetscObjectComm((PetscObject)mat),&size);CHKERRMPI(ierr);
658c3ff71bSJunchao Zhang     ierr = MatCreate(PETSC_COMM_SELF,&mpiaij->B);CHKERRQ(ierr);
666a29ce69SStefano Zampini     ierr = MatSetSizes(mpiaij->B,mat->rmap->n,size > 1 ? mat->cmap->N : 0,mat->rmap->n,size > 1 ? mat->cmap->N : 0);CHKERRQ(ierr);
678c3ff71bSJunchao Zhang     ierr = PetscLogObjectParent((PetscObject)mat,(PetscObject)mpiaij->B);CHKERRQ(ierr);
688c3ff71bSJunchao Zhang   }
696a29ce69SStefano Zampini   ierr = MatSetType(mpiaij->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
706a29ce69SStefano Zampini   ierr = MatSetType(mpiaij->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);
718c3ff71bSJunchao Zhang   ierr = MatSeqAIJSetPreallocation(mpiaij->A,d_nz,d_nnz);CHKERRQ(ierr);
728c3ff71bSJunchao Zhang   ierr = MatSeqAIJSetPreallocation(mpiaij->B,o_nz,o_nnz);CHKERRQ(ierr);
738c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
748c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
758c3ff71bSJunchao Zhang }
768c3ff71bSJunchao Zhang 
778c3ff71bSJunchao Zhang PetscErrorCode MatMult_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
788c3ff71bSJunchao Zhang {
798c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
808c3ff71bSJunchao Zhang   PetscErrorCode ierr;
818c3ff71bSJunchao Zhang   PetscInt       nt;
828c3ff71bSJunchao Zhang 
838c3ff71bSJunchao Zhang   PetscFunctionBegin;
848c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
85c0aa6a63SJacob Faibussowitsch   if (nt != mat->cmap->n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
868c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
878c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->mult)(mpiaij->A,xx,yy);CHKERRQ(ierr);
888c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
898c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,yy,yy);CHKERRQ(ierr);
908c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
918c3ff71bSJunchao Zhang }
928c3ff71bSJunchao Zhang 
938c3ff71bSJunchao Zhang PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat,Vec xx,Vec yy,Vec zz)
948c3ff71bSJunchao Zhang {
958c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
968c3ff71bSJunchao Zhang   PetscErrorCode ierr;
978c3ff71bSJunchao Zhang   PetscInt       nt;
988c3ff71bSJunchao Zhang 
998c3ff71bSJunchao Zhang   PetscFunctionBegin;
1008c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
101c0aa6a63SJacob Faibussowitsch   if (nt != mat->cmap->n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->cmap->n,nt);
1028c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
1038c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->multadd)(mpiaij->A,xx,yy,zz);CHKERRQ(ierr);
1048c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,xx,mpiaij->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
1058c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multadd)(mpiaij->B,mpiaij->lvec,zz,zz);CHKERRQ(ierr);
1068c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1078c3ff71bSJunchao Zhang }
1088c3ff71bSJunchao Zhang 
1098c3ff71bSJunchao Zhang PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat,Vec xx,Vec yy)
1108c3ff71bSJunchao Zhang {
1118c3ff71bSJunchao Zhang   Mat_MPIAIJ     *mpiaij = (Mat_MPIAIJ*)mat->data;
1128c3ff71bSJunchao Zhang   PetscErrorCode ierr;
1138c3ff71bSJunchao Zhang   PetscInt       nt;
1148c3ff71bSJunchao Zhang 
1158c3ff71bSJunchao Zhang   PetscFunctionBegin;
1168c3ff71bSJunchao Zhang   ierr = VecGetLocalSize(xx,&nt);CHKERRQ(ierr);
117c0aa6a63SJacob Faibussowitsch   if (nt != mat->rmap->n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")",mat->rmap->n,nt);
1188c3ff71bSJunchao Zhang   ierr = (*mpiaij->B->ops->multtranspose)(mpiaij->B,xx,mpiaij->lvec);CHKERRQ(ierr);
1198c3ff71bSJunchao Zhang   ierr = (*mpiaij->A->ops->multtranspose)(mpiaij->A,xx,yy);CHKERRQ(ierr);
1208c3ff71bSJunchao Zhang   ierr = VecScatterBegin(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
1218c3ff71bSJunchao Zhang   ierr = VecScatterEnd(mpiaij->Mvctx,mpiaij->lvec,yy,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
1228c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1238c3ff71bSJunchao Zhang }
1248c3ff71bSJunchao Zhang 
125076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
126076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
127076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
128076ba34aSJunchao Zhang */
129076ba34aSJunchao Zhang PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat,MatReuse reuse,IS *glob,Mat *C)
130076ba34aSJunchao Zhang {
131076ba34aSJunchao Zhang   Mat            Ad,Ao;
132076ba34aSJunchao Zhang   const PetscInt *cmap;
133076ba34aSJunchao Zhang   PetscErrorCode ierr;
134076ba34aSJunchao Zhang 
135076ba34aSJunchao Zhang   PetscFunctionBegin;
136076ba34aSJunchao Zhang   ierr = MatMPIAIJGetSeqAIJ(mat,&Ad,&Ao,&cmap);CHKERRQ(ierr);
137076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosMergeMats(Ad,Ao,reuse,C);CHKERRQ(ierr);
138076ba34aSJunchao Zhang   if (glob) {
139076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
140076ba34aSJunchao Zhang     ierr = MatGetLocalSize(Ad,NULL,&dn);CHKERRQ(ierr);
141076ba34aSJunchao Zhang     ierr = MatGetLocalSize(Ao,NULL,&on);CHKERRQ(ierr);
142076ba34aSJunchao Zhang     ierr = MatGetOwnershipRangeColumn(mat,&cst,NULL);CHKERRQ(ierr);
143076ba34aSJunchao Zhang     ierr = PetscMalloc1(dn+on,&gidx);CHKERRQ(ierr);
144076ba34aSJunchao Zhang     for (i=0; i<dn; i++) gidx[i]    = cst + i;
145076ba34aSJunchao Zhang     for (i=0; i<on; i++) gidx[i+dn] = cmap[i];
146076ba34aSJunchao Zhang     ierr = ISCreateGeneral(PetscObjectComm((PetscObject)Ad),dn+on,gidx,PETSC_OWN_POINTER,glob);CHKERRQ(ierr);
147076ba34aSJunchao Zhang   }
148076ba34aSJunchao Zhang   PetscFunctionReturn(0);
149076ba34aSJunchao Zhang }
150076ba34aSJunchao Zhang 
151076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
152076ba34aSJunchao Zhang struct MatMatStruct {
153076ba34aSJunchao Zhang   MatRowMapKokkosView   Cdstart; /* Used to split sequential matrix into petsc's A, B format */
154076ba34aSJunchao Zhang   PetscSF               sf; /* SF to send/recv matrix entries */
155076ba34aSJunchao Zhang   MatScalarKokkosView   abuf; /* buf of mat values in send/recv */
156076ba34aSJunchao Zhang   Mat                   C1,C2,B_local;
157076ba34aSJunchao Zhang   KokkosCsrMatrix       C1_global,C2_global,C_global;
158076ba34aSJunchao Zhang   KernelHandle          kh;
159076ba34aSJunchao Zhang   MatMatStruct() {
160076ba34aSJunchao Zhang     C1 = C2 = B_local = NULL;
161076ba34aSJunchao Zhang     sf = NULL;
162076ba34aSJunchao Zhang   }
163076ba34aSJunchao Zhang 
164076ba34aSJunchao Zhang   ~MatMatStruct() {
165076ba34aSJunchao Zhang     MatDestroy(&C1);
166076ba34aSJunchao Zhang     MatDestroy(&C2);
167076ba34aSJunchao Zhang     MatDestroy(&B_local);
168076ba34aSJunchao Zhang     PetscSFDestroy(&sf);
169076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
170076ba34aSJunchao Zhang   }
171076ba34aSJunchao Zhang };
172076ba34aSJunchao Zhang 
173076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
174076ba34aSJunchao Zhang   MatColIdxKokkosView   rows;
175076ba34aSJunchao Zhang   MatRowMapKokkosView   rowoffset;
176076ba34aSJunchao Zhang   Mat                   B_other,C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
177076ba34aSJunchao Zhang 
178076ba34aSJunchao Zhang   MatMatStruct_AB() : B_other(NULL),C_petsc(NULL){}
179076ba34aSJunchao Zhang   ~MatMatStruct_AB() {
180076ba34aSJunchao Zhang     MatDestroy(&B_other);
181076ba34aSJunchao Zhang     MatDestroy(&C_petsc);
182076ba34aSJunchao Zhang   }
183076ba34aSJunchao Zhang };
184076ba34aSJunchao Zhang 
185076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
186076ba34aSJunchao Zhang   MatRowMapKokkosView   srcrowoffset,dstrowoffset;
187076ba34aSJunchao Zhang };
188076ba34aSJunchao Zhang 
189076ba34aSJunchao Zhang struct MatProductData_MPIAIJKokkos
190076ba34aSJunchao Zhang {
191076ba34aSJunchao Zhang   MatMatStruct_AB   *mmAB;
192076ba34aSJunchao Zhang   MatMatStruct_AtB  *mmAtB;
193076ba34aSJunchao Zhang   PetscBool         reusesym;
194076ba34aSJunchao Zhang 
195076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos(): mmAB(NULL),mmAtB(NULL),reusesym(PETSC_FALSE){}
196076ba34aSJunchao Zhang   ~MatProductData_MPIAIJKokkos() {
197076ba34aSJunchao Zhang     delete mmAB;
198076ba34aSJunchao Zhang     delete mmAtB;
199076ba34aSJunchao Zhang   }
200076ba34aSJunchao Zhang };
201076ba34aSJunchao Zhang 
202076ba34aSJunchao Zhang static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
203076ba34aSJunchao Zhang {
204076ba34aSJunchao Zhang   PetscFunctionBegin;
205076ba34aSJunchao Zhang   CHKERRCXX(delete static_cast<MatProductData_MPIAIJKokkos*>(data));
206076ba34aSJunchao Zhang   PetscFunctionReturn(0);
207076ba34aSJunchao Zhang }
208076ba34aSJunchao Zhang 
209076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
210076ba34aSJunchao Zhang 
211076ba34aSJunchao Zhang    Input Parameters:
212076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
213076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
214076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
215076ba34aSJunchao Zhang 
216076ba34aSJunchao Zhang    Output Parameters:
217076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
218076ba34aSJunchao Zhang  */
219076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A,PetscInt N,const ConstMatColIdxKokkosView& l2g,KokkosCsrMatrix& csrmat)
220076ba34aSJunchao Zhang {
221076ba34aSJunchao Zhang   KokkosCsrMatrix&         orig = static_cast<Mat_SeqAIJKokkos*>(A->spptr)->csrmat;
222076ba34aSJunchao Zhang   MatColIdxKokkosView      jg("jg",orig.nnz()); /* New j array for csrmat */
223076ba34aSJunchao Zhang 
224076ba34aSJunchao Zhang   PetscFunctionBegin;
225076ba34aSJunchao Zhang   CHKERRCXX(Kokkos::parallel_for(orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) {jg(i) = l2g(orig.graph.entries(i));}));
226076ba34aSJunchao Zhang   CHKERRCXX(csrmat = KokkosCsrMatrix("csrmat",orig.numRows(),N,orig.nnz(),orig.values,orig.graph.row_map,jg));
227076ba34aSJunchao Zhang   PetscFunctionReturn(0);
228076ba34aSJunchao Zhang }
229076ba34aSJunchao Zhang 
230076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
231076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
232076ba34aSJunchao Zhang 
233076ba34aSJunchao Zhang   Input Parameters:
234076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
235076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
236076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
237076ba34aSJunchao Zhang 
238076ba34aSJunchao Zhang   Output Parameters:
239076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
240076ba34aSJunchao Zhang */
241076ba34aSJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat,Mat A,Mat B)
242076ba34aSJunchao Zhang {
243076ba34aSJunchao Zhang   PetscErrorCode      ierr;
244076ba34aSJunchao Zhang   Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ*>(mat->data);
245076ba34aSJunchao Zhang   PetscInt            m,n,M,N,Am,An,Bm,Bn;
246076ba34aSJunchao Zhang   Mat_SeqAIJKokkos    *bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
247076ba34aSJunchao Zhang 
248076ba34aSJunchao Zhang   PetscFunctionBegin;
249076ba34aSJunchao Zhang   ierr = MatGetSize(mat,&M,&N);CHKERRQ(ierr);
250076ba34aSJunchao Zhang   ierr = MatGetLocalSize(mat,&m,&n);CHKERRQ(ierr);
251076ba34aSJunchao Zhang   ierr = MatGetLocalSize(A,&Am,&An);CHKERRQ(ierr);
252076ba34aSJunchao Zhang   ierr = MatGetLocalSize(B,&Bm,&Bn);CHKERRQ(ierr);
253076ba34aSJunchao Zhang 
254076ba34aSJunchao Zhang   if (m != Am || m != Bm) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of rows do not match");
255076ba34aSJunchao Zhang   if (n != An) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"local number of columns do not match");
256076ba34aSJunchao Zhang   if (N != Bn) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"global number of columns do not match");
257076ba34aSJunchao Zhang   if (mpiaij->A || mpiaij->B) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"A, B of the MPIAIJ matrix are not empty");
258076ba34aSJunchao Zhang   mpiaij->A = A;
259076ba34aSJunchao Zhang   mpiaij->B = B;
260076ba34aSJunchao Zhang 
261076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
262076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
263076ba34aSJunchao Zhang 
264076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE);CHKERRQ(ierr);
265076ba34aSJunchao Zhang   ierr = MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
266076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
267076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
268076ba34aSJunchao Zhang   */
269076ba34aSJunchao Zhang   ierr = MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
270076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NO_OFF_PROC_ENTRIES,PETSC_FALSE);CHKERRQ(ierr);
271076ba34aSJunchao Zhang   ierr = MatSetOption(mat,MAT_NEW_NONZERO_LOCATION_ERR,PETSC_TRUE);CHKERRQ(ierr);
272076ba34aSJunchao Zhang 
273076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
274076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
275076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
276076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
277076ba34aSJunchao Zhang   PetscFunctionReturn(0);
278076ba34aSJunchao Zhang }
279076ba34aSJunchao Zhang 
280076ba34aSJunchao Zhang /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
281076ba34aSJunchao Zhang 
282076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
283076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
284076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
285076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
286076ba34aSJunchao Zhang 
287076ba34aSJunchao Zhang    Collective on comm of ownerSF
288076ba34aSJunchao Zhang 
289076ba34aSJunchao Zhang    Input Parameters:
290076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
291076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
292076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
293076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
294076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
295076ba34aSJunchao Zhang 
296076ba34aSJunchao Zhang    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
297076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
298076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
299076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
300076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
301076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
302076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
303076ba34aSJunchao Zhang */
304076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosBcast(Mat B,MatReuse reuse,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
305076ba34aSJunchao Zhang                                            PetscSF& bcastSF,MatScalarKokkosView& abuf,MatColIdxKokkosView& rows,
306076ba34aSJunchao Zhang                                            MatRowMapKokkosView& rowoffset,Mat& C)
307076ba34aSJunchao Zhang {
308076ba34aSJunchao Zhang   PetscErrorCode               ierr;
309076ba34aSJunchao Zhang   Mat_SeqAIJKokkos             *bkok,*ckok;
310076ba34aSJunchao Zhang 
311076ba34aSJunchao Zhang   PetscFunctionBegin;
312076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosSyncDevice(B);CHKERRQ(ierr); /* Make sure B->spptr is accessible */
313076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos*>(B->spptr);
314076ba34aSJunchao Zhang 
315076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
316076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
317076ba34aSJunchao Zhang 
318076ba34aSJunchao Zhang     const auto& Ba = bkok->a_dual.view_device();
319076ba34aSJunchao Zhang     const auto& Bi = bkok->i_dual.view_device();
320076ba34aSJunchao Zhang     const auto& Ca = ckok->a_dual.view_device();
321076ba34aSJunchao Zhang 
322076ba34aSJunchao Zhang     /* Copy Ba to abuf */
323076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
324076ba34aSJunchao Zhang       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
325076ba34aSJunchao Zhang       PetscInt r    = rows(i);
326076ba34aSJunchao Zhang       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
327076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
328076ba34aSJunchao Zhang         abuf(base+k) = Ba(Bi(r)+k);
329076ba34aSJunchao Zhang       });
330076ba34aSJunchao Zhang     });
331076ba34aSJunchao Zhang 
332076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
333076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr); /* TODO: get memtype for abuf */
334076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.data(),MPI_REPLACE);CHKERRQ(ierr);
335076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
336076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
337076ba34aSJunchao Zhang     MPI_Comm       comm;
338076ba34aSJunchao Zhang     PetscMPIInt    tag;
339076ba34aSJunchao Zhang     PetscInt       k,Cm,Cn,Cnnz,*Ci_h,nroots,nleaves;
340076ba34aSJunchao Zhang 
341076ba34aSJunchao Zhang     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRMPI(ierr);
342076ba34aSJunchao Zhang     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
343076ba34aSJunchao Zhang     Cm   = nleaves; /* row size of C */
344076ba34aSJunchao Zhang     Cn   = N;  /* col size of C, which initially uses global ids, so we can safely set its col size as N */
345076ba34aSJunchao Zhang 
346076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
347076ba34aSJunchao Zhang     PetscInt       *Browlens;
348076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
349076ba34aSJunchao Zhang     ierr = PetscMalloc1(nroots,&Browlens);CHKERRQ(ierr);
350076ba34aSJunchao Zhang     for (k=0; k<nroots; k++) Browlens[k] = tmp[k+1]-tmp[k];
351076ba34aSJunchao Zhang 
352076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
353076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i",Cm+1); /* C's rowmap */
354076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
355076ba34aSJunchao Zhang     Ci_h[0] = 0;
356076ba34aSJunchao Zhang     ierr    = PetscSFBcastWithMemTypeBegin(ownerSF,MPIU_INT,PETSC_MEMTYPE_HOST,Browlens,PETSC_MEMTYPE_HOST,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
357076ba34aSJunchao Zhang     ierr    = PetscSFBcastEnd(ownerSF,MPIU_INT,Browlens,&Ci_h[1],MPI_REPLACE);CHKERRQ(ierr);
358076ba34aSJunchao Zhang     for (k=1; k<Cm+1; k++) Ci_h[k] += Ci_h[k-1]; /* Convert lens to CSR */
359076ba34aSJunchao Zhang     Cnnz    = Ci_h[Cm];
360076ba34aSJunchao Zhang     Ci.modify_host();
361076ba34aSJunchao Zhang     Ci.sync_device();
362076ba34aSJunchao Zhang 
363076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
364076ba34aSJunchao Zhang     MatColIdxKokkosDualView  Cj("j",Cnnz);
365076ba34aSJunchao Zhang     MatScalarKokkosDualView  Ca("a",Cnnz);
366076ba34aSJunchao Zhang 
367076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
368076ba34aSJunchao Zhang     const PetscMPIInt *iranks,*ranks;
369076ba34aSJunchao Zhang     const PetscInt    *ioffset,*irootloc,*roffset;
370076ba34aSJunchao Zhang     PetscInt          i,j,niranks,nranks,*sdisp,*rdisp,*rowptr;
371076ba34aSJunchao Zhang     MPI_Request       *reqs;
372076ba34aSJunchao Zhang 
373076ba34aSJunchao Zhang     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&irootloc);CHKERRQ(ierr); /* irootloc[] contains indices of rows I need to send to each receiver */
374076ba34aSJunchao Zhang     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* recv info */
375076ba34aSJunchao Zhang 
376076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
377076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
378076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
379076ba34aSJunchao Zhang 
380076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
381076ba34aSJunchao Zhang     */
382076ba34aSJunchao Zhang     ierr = PetscMalloc3(niranks+1,&sdisp,nranks,&rdisp,niranks+nranks,&reqs);CHKERRQ(ierr);
383076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h",ioffset[niranks]+1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
384076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
385076ba34aSJunchao Zhang 
386076ba34aSJunchao Zhang     sdisp[0] = 0;
387076ba34aSJunchao Zhang     rowptr[0]  = 0;
388076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) { /* for each receiver */
389076ba34aSJunchao Zhang       PetscInt len, nz = 0;
390076ba34aSJunchao Zhang       for (j=ioffset[i]; j<ioffset[i+1]; j++) { /* for each row to this receiver */
391076ba34aSJunchao Zhang         len         = Browlens[irootloc[j]];
392076ba34aSJunchao Zhang         rowptr[j+1] = rowptr[j] + len;
393076ba34aSJunchao Zhang         nz         += len;
394076ba34aSJunchao Zhang       }
395076ba34aSJunchao Zhang       sdisp[i+1] = sdisp[i] + nz;
396076ba34aSJunchao Zhang     }
397076ba34aSJunchao Zhang     ierr = PetscCommGetNewTag(comm,&tag);CHKERRMPI(ierr);
398076ba34aSJunchao Zhang     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
399076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
400076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
401076ba34aSJunchao Zhang 
402076ba34aSJunchao Zhang     PetscInt    nleaves2 = Cnnz; /* leaves are the nonzeros I will receive */
403076ba34aSJunchao Zhang     PetscInt    nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
404076ba34aSJunchao Zhang     PetscSFNode *iremote;
405076ba34aSJunchao Zhang     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr);
406076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) { /* for each sender */
407076ba34aSJunchao Zhang       k = 0;
408076ba34aSJunchao Zhang       for (j=Ci_h[roffset[i]]; j<Ci_h[roffset[i+1]]; j++) {
409076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
410076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
411076ba34aSJunchao Zhang         k++;
412076ba34aSJunchao Zhang       }
413076ba34aSJunchao Zhang     }
414076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
415076ba34aSJunchao Zhang     ierr = PetscSFCreate(comm,&bcastSF);CHKERRQ(ierr);
416076ba34aSJunchao Zhang     ierr = PetscSFSetGraph(bcastSF,nroots2,nleaves2,NULL/*ilocal*/,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
417076ba34aSJunchao Zhang 
418076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
419076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
420076ba34aSJunchao Zhang     */
421076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc,ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
422076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows",ioffset[niranks]);
423076ba34aSJunchao Zhang     Kokkos::deep_copy(rows,rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
424076ba34aSJunchao Zhang 
425076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
426076ba34aSJunchao Zhang 
427076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf",sdisp[niranks]); /* send buf for (global) col ids */
428076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf",sdisp[niranks]); /* send buf for mat values */
429076ba34aSJunchao Zhang 
430076ba34aSJunchao Zhang     const auto& Ba = bkok->a_dual.view_device();
431076ba34aSJunchao Zhang     const auto& Bi = bkok->i_dual.view_device();
432076ba34aSJunchao Zhang     const auto& Bj = bkok->j_dual.view_device();
433076ba34aSJunchao Zhang 
434076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
435076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
436076ba34aSJunchao Zhang       PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
437076ba34aSJunchao Zhang       PetscInt r    = rows(i);
438076ba34aSJunchao Zhang       PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
439076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,Bi(r+1)-Bi(r)),[&](PetscInt k) {
440076ba34aSJunchao Zhang         abuf(base+k) = Ba(Bi(r)+k);
441076ba34aSJunchao Zhang         jbuf(base+k) = l2g(Bj(Bi(r)+k));
442076ba34aSJunchao Zhang       });
443076ba34aSJunchao Zhang     });
444076ba34aSJunchao Zhang 
445076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
446076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
447076ba34aSJunchao Zhang     ierr = PetscSFBcastBegin(bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
448076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_INT,   jbuf.data(),Cj.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
449076ba34aSJunchao Zhang     ierr = PetscSFBcastEnd  (bcastSF,MPIU_SCALAR,abuf.data(),Ca.view_device().data(),MPI_REPLACE);CHKERRQ(ierr);
450076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
451076ba34aSJunchao Zhang     Cj.sync_host();
452076ba34aSJunchao Zhang     Ca.modify_device();
453076ba34aSJunchao Zhang 
454076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
455076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm,Cn,Cnnz,Ci,Cj,Ca);
456076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,ckok,&C);CHKERRQ(ierr);
457076ba34aSJunchao Zhang     ierr = PetscFree3(sdisp,rdisp,reqs);CHKERRQ(ierr);
458076ba34aSJunchao Zhang     ierr = PetscFree(Browlens);CHKERRQ(ierr);
459546078acSJacob Faibussowitsch   } else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
460076ba34aSJunchao Zhang   PetscFunctionReturn(0);
461076ba34aSJunchao Zhang }
462076ba34aSJunchao Zhang 
463076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
464076ba34aSJunchao Zhang 
465076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
466076ba34aSJunchao Zhang 
467076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
468076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
469076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
470076ba34aSJunchao Zhang 
471076ba34aSJunchao Zhang   Input Parameters:
472076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
473076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
474076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
475076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
476076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
477076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
478076ba34aSJunchao Zhang 
479076ba34aSJunchao Zhang   Output Parameters:
480076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
481076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
482076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
483076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
484076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
485076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
486076ba34aSJunchao Zhang 
487076ba34aSJunchao Zhang    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
488076ba34aSJunchao Zhang  */
489076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosReduce(Mat A,MatReuse reuse,PetscBool local,PetscInt N,const ConstMatColIdxKokkosView& l2g,PetscSF ownerSF,
490076ba34aSJunchao Zhang                                             PetscSF& reduceSF,MatScalarKokkosView& abuf,
491076ba34aSJunchao Zhang                                             MatRowMapKokkosView& srcrowoffset,MatRowMapKokkosView& dstrowoffset,
492076ba34aSJunchao Zhang                                             KokkosCsrMatrix& C)
493076ba34aSJunchao Zhang {
494076ba34aSJunchao Zhang   PetscErrorCode         ierr;
495076ba34aSJunchao Zhang   PetscInt               i,r,Am,An,Annz,Cnnz,nrows;
496076ba34aSJunchao Zhang   const PetscInt         *Ai;
497076ba34aSJunchao Zhang   Mat_SeqAIJKokkos       *akok;
498076ba34aSJunchao Zhang 
499076ba34aSJunchao Zhang   PetscFunctionBegin;
500076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosSyncDevice(A);CHKERRQ(ierr); /* So that A's latest data is on device */
501076ba34aSJunchao Zhang   ierr = MatGetSize(A,&Am,&An);
502076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ*>(A->data)->i;
503076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
504076ba34aSJunchao Zhang   Annz = Ai[Am];
505076ba34aSJunchao Zhang 
506076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
507076ba34aSJunchao Zhang     /* Send Aa to abuf */
508076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
509076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
510076ba34aSJunchao Zhang 
511076ba34aSJunchao Zhang     /* Copy abuf to Ca */
512076ba34aSJunchao Zhang     const MatScalarKokkosView& Ca = C.values;
513076ba34aSJunchao Zhang     nrows = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
514076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
515076ba34aSJunchao Zhang       PetscInt i   = t.league_rank();
516076ba34aSJunchao Zhang       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
517076ba34aSJunchao Zhang       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
518076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {Ca(dst+k) = abuf(src+k);});
519076ba34aSJunchao Zhang     });
520076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
521076ba34aSJunchao Zhang     MPI_Comm               comm;
522076ba34aSJunchao Zhang     MPI_Request            *reqs;
523076ba34aSJunchao Zhang     PetscMPIInt            tag;
524076ba34aSJunchao Zhang     PetscInt               Cm;
525076ba34aSJunchao Zhang 
526076ba34aSJunchao Zhang     ierr = PetscObjectGetComm((PetscObject)ownerSF,&comm);CHKERRQ(ierr);
527076ba34aSJunchao Zhang     ierr = PetscCommGetNewTag(comm,&tag);CHKERRQ(ierr);
528076ba34aSJunchao Zhang 
529076ba34aSJunchao Zhang     PetscInt niranks,nranks,nroots,nleaves;
530076ba34aSJunchao Zhang     const PetscMPIInt *iranks,*ranks;
531076ba34aSJunchao Zhang     const PetscInt *ioffset,*rows,*roffset;  /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
532076ba34aSJunchao Zhang     ierr = PetscSFSetUp(ownerSF);CHKERRQ(ierr);
533076ba34aSJunchao Zhang     ierr = PetscSFGetLeafRanks(ownerSF,&niranks,&iranks,&ioffset,&rows);CHKERRQ(ierr); /* recv info: iranks[] will send rows to me */
534076ba34aSJunchao Zhang     ierr = PetscSFGetRootRanks(ownerSF,&nranks,&ranks,&roffset,NULL/*rmine*/,NULL/*rremote*/);CHKERRQ(ierr); /* send info */
535076ba34aSJunchao Zhang     ierr = PetscSFGetGraph(ownerSF,&nroots,&nleaves,NULL,NULL);CHKERRQ(ierr);
536546078acSJacob Faibussowitsch     if (nleaves != Am) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_PLIB,"ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")",nleaves,Am);
537076ba34aSJunchao Zhang     Cm    = nroots;
538076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
539076ba34aSJunchao Zhang 
540076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
541076ba34aSJunchao Zhang     PetscInt                *srowlens; /* send buf of row lens */
542076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h",nrows+1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
543076ba34aSJunchao Zhang     PetscInt                *rrowlens = rrowlens_h.data();
544076ba34aSJunchao Zhang 
545076ba34aSJunchao Zhang     ierr = PetscMalloc2(Am,&srowlens,niranks+nranks,&reqs);CHKERRQ(ierr);
546076ba34aSJunchao Zhang     for (i=0; i<Am; i++) srowlens[i] = Ai[i+1] - Ai[i];
547076ba34aSJunchao Zhang     rrowlens[0] = 0;
548076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
549076ba34aSJunchao Zhang     for (i=0; i<niranks; i++){ierr = MPI_Irecv(&rrowlens[ioffset[i]],ioffset[i+1]-ioffset[i],MPIU_INT,iranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
550076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) {ierr = MPI_Isend(&srowlens[roffset[i]],roffset[i+1]-roffset[i],MPIU_INT,ranks[i],tag,comm,&reqs[niranks+i]);CHKERRMPI(ierr);}
551076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
552076ba34aSJunchao Zhang 
553076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
554076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i",Cm+1);
555076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h,0); /* Zero Ci */
556076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
557076ba34aSJunchao Zhang 
558076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) {
559076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
560076ba34aSJunchao Zhang      #if defined(PETSC_USE_DEBUG)
561546078acSJacob Faibussowitsch       if (r<0 || r>=Cm) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_PLIB,"local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")",r,Cm);
562076ba34aSJunchao Zhang      #endif
563076ba34aSJunchao Zhang       Ci_ptr[r+1] += rrowlens[i]; /* add to length of row r in C */
564076ba34aSJunchao Zhang     }
565076ba34aSJunchao Zhang     for (i=0; i<Cm; i++) Ci_ptr[i+1] += Ci_ptr[i]; /* to CSR format */
566076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
567076ba34aSJunchao Zhang 
568076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
569076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h",nrows);
570076ba34aSJunchao Zhang     PetscInt                *dstrowoffset_hptr = dstrowoffset_h.data();
571076ba34aSJunchao Zhang     PetscInt                *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
572076ba34aSJunchao Zhang 
573076ba34aSJunchao Zhang     ierr = PetscCalloc1(Cm,&currowlens);CHKERRQ(ierr); /* Init with zero, to be added to */
574076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) { /* for each row I receive */
575076ba34aSJunchao Zhang       r                    = rows[i]; /* row id in C */
576076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
577076ba34aSJunchao Zhang       currowlens[r]       += rrowlens[i]; /* accumulate to length of row r in C */
578076ba34aSJunchao Zhang     }
579076ba34aSJunchao Zhang     ierr = PetscFree(currowlens);CHKERRQ(ierr);
580076ba34aSJunchao Zhang 
581076ba34aSJunchao Zhang     rrowlens--;
582076ba34aSJunchao Zhang     for (i=0; i<nrows; i++) rrowlens[i+1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
583076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),dstrowoffset_h);
584076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
585076ba34aSJunchao Zhang 
586076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
587076ba34aSJunchao Zhang     PetscInt *sdisp,*rdisp; /* buffer to send offsets of roots, and buffer to recv them */
588076ba34aSJunchao Zhang     ierr = PetscMalloc2(niranks,&sdisp,nranks,&rdisp);CHKERRQ(ierr);
589076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
590076ba34aSJunchao Zhang     for (i=0; i<nranks; i++)  {ierr = MPI_Irecv(&rdisp[i],1,MPIU_INT,ranks[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
591076ba34aSJunchao Zhang     for (i=0; i<niranks; i++) {ierr = MPI_Isend(&sdisp[i],1,MPIU_INT,iranks[i],tag,comm,&reqs[nranks+i]);CHKERRMPI(ierr);}
592076ba34aSJunchao Zhang     ierr = MPI_Waitall(niranks+nranks,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
593076ba34aSJunchao Zhang 
594076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
595076ba34aSJunchao Zhang     PetscInt    nroots2 = Cnnz,nleaves2 = Annz;
596076ba34aSJunchao Zhang     PetscSFNode *iremote;
597076ba34aSJunchao Zhang     ierr = PetscMalloc1(nleaves2,&iremote);CHKERRQ(ierr); /* no free, since memory will be given to reduceSF */
598076ba34aSJunchao Zhang     for (i=0; i<nranks; i++) {
599076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i]; /* root offset at this root rank */
600076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]]; /* leaf base */
601076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i+1]] - leafbase; /* I will send nz nonzeros to this root rank */
602076ba34aSJunchao Zhang       for (PetscInt k=0; k<nz; k++) {
603076ba34aSJunchao Zhang         iremote[leafbase+k].rank  = ranks[i];
604076ba34aSJunchao Zhang         iremote[leafbase+k].index = rootbase + k;
605076ba34aSJunchao Zhang       }
606076ba34aSJunchao Zhang     }
607076ba34aSJunchao Zhang     ierr = PetscSFCreate(comm,&reduceSF);CHKERRQ(ierr);
608076ba34aSJunchao Zhang     ierr = PetscSFSetGraph(reduceSF,nroots2,nleaves2,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
609076ba34aSJunchao Zhang     ierr = PetscFree2(sdisp,rdisp);CHKERRQ(ierr);
610076ba34aSJunchao Zhang 
611076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
612076ba34aSJunchao Zhang 
613076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
614076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
615076ba34aSJunchao Zhang     if (local) {
616076ba34aSJunchao Zhang       Ajg = MatColIdxKokkosView("j",Annz);
617076ba34aSJunchao Zhang       const MatColIdxKokkosView& Aj = akok->j_dual.view_device();
618076ba34aSJunchao Zhang       Kokkos::parallel_for(Annz,KOKKOS_LAMBDA(const PetscInt i) {Ajg(i) = l2g(Aj(i));});
619076ba34aSJunchao Zhang     } else {
620076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
621076ba34aSJunchao Zhang     }
622076ba34aSJunchao Zhang 
623076ba34aSJunchao Zhang     MatColIdxKokkosView   jbuf("jbuf",Cnnz);
624076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf",Cnnz);
625076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
626076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_INT,   Ajg.data(),           jbuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
627076ba34aSJunchao Zhang     ierr = PetscSFReduceBegin(reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
628076ba34aSJunchao Zhang     ierr = PetscSFReduceEnd  (reduceSF,MPIU_SCALAR,akok->a_device_data(),abuf.data(),MPI_REPLACE);CHKERRMPI(ierr);
629076ba34aSJunchao Zhang 
630076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
631076ba34aSJunchao Zhang     MatRowMapKokkosView    Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),Ci_h); /* Ci is an alias of Ci_h if no device */
632076ba34aSJunchao Zhang     MatColIdxKokkosView    Cj("j",Cnnz);
633076ba34aSJunchao Zhang     MatScalarKokkosView    Ca("a",Cnnz);
634076ba34aSJunchao Zhang 
635076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
636076ba34aSJunchao Zhang       PetscInt i   = t.league_rank();
637076ba34aSJunchao Zhang       PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
638076ba34aSJunchao Zhang       PetscInt len = srcrowoffset(i+1) - srcrowoffset(i);
639076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t,len), [&](PetscInt k) {
640076ba34aSJunchao Zhang         Ca(dst+k) = abuf(src+k);
641076ba34aSJunchao Zhang         Cj(dst+k) = jbuf(src+k);
642076ba34aSJunchao Zhang       });
643076ba34aSJunchao Zhang     });
644076ba34aSJunchao Zhang 
645076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
646076ba34aSJunchao Zhang     C    = KokkosCsrMatrix("csrmat",Cm,N,Cnnz,Ca,Ci,Cj);
647076ba34aSJunchao Zhang     ierr = PetscFree2(srowlens,reqs);CHKERRQ(ierr);
648546078acSJacob Faibussowitsch   } else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Unsupported MatReuse enum %d",reuse);
649076ba34aSJunchao Zhang   PetscFunctionReturn(0);
650076ba34aSJunchao Zhang }
651076ba34aSJunchao Zhang 
652076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a MATMPIAIJKOKKOS matrix by splitting a KokkosCsrMatrix
653076ba34aSJunchao Zhang 
654076ba34aSJunchao Zhang   Input Parameters:
655076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
656076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
657076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
658076ba34aSJunchao Zhang -  Cdstart  - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the start of the first
659076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
660076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
661076ba34aSJunchao Zhang 
662076ba34aSJunchao Zhang   Output Parameters:
663076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
664076ba34aSJunchao Zhang -  Cdstart - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
665076ba34aSJunchao Zhang 
666076ba34aSJunchao Zhang   Notes:
667076ba34aSJunchao Zhang    Between calls with MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX, csrmat must have the same nonzero pattern
668076ba34aSJunchao Zhang  */
669076ba34aSJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C,MatReuse reuse,const KokkosCsrMatrix& csrmat,MatRowMapKokkosView& Cdstart)
670076ba34aSJunchao Zhang {
671076ba34aSJunchao Zhang   PetscErrorCode                  ierr;
672076ba34aSJunchao Zhang   const MatScalarKokkosView&      Ca = csrmat.values;
673076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView& Ci = csrmat.graph.row_map;
674076ba34aSJunchao Zhang   PetscInt                        m,n,N;
675076ba34aSJunchao Zhang 
676076ba34aSJunchao Zhang   PetscFunctionBegin;
677076ba34aSJunchao Zhang   ierr = MatGetLocalSize(C,&m,&n);CHKERRQ(ierr);
678076ba34aSJunchao Zhang   ierr = MatGetSize(C,NULL,&N);CHKERRQ(ierr);
679076ba34aSJunchao Zhang 
680076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
681076ba34aSJunchao Zhang     Mat_MPIAIJ                  *mpiaij = static_cast<Mat_MPIAIJ*>(C->data);
682076ba34aSJunchao Zhang     Mat_SeqAIJKokkos            *akok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->A->spptr);
683076ba34aSJunchao Zhang     Mat_SeqAIJKokkos            *bkok = static_cast<Mat_SeqAIJKokkos*>(mpiaij->B->spptr);
684076ba34aSJunchao Zhang     const MatScalarKokkosView&  Cda = akok->a_dual.view_device(),Coa = bkok->a_dual.view_device();
685076ba34aSJunchao Zhang     const MatRowMapKokkosView&  Cdi = akok->i_dual.view_device(),Coi = bkok->i_dual.view_device();
686076ba34aSJunchao Zhang 
687076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
688076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
689076ba34aSJunchao Zhang       PetscInt i       = t.league_rank(); /* row i */
690076ba34aSJunchao Zhang       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
691076ba34aSJunchao Zhang       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
692076ba34aSJunchao Zhang       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
693076ba34aSJunchao Zhang       PetscInt cdend   = cdstart + cdlen;
694076ba34aSJunchao Zhang       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
695076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
696076ba34aSJunchao Zhang         if (k < cdstart) {  /* k in [0, cdstart) */
697076ba34aSJunchao Zhang           Coa(Coi(i)+k) = Ca(Ci(i)+k);
698076ba34aSJunchao Zhang         } else if (k < cdend) { /* k in [cdstart, cdend) */
699076ba34aSJunchao Zhang           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
700076ba34aSJunchao Zhang         } else { /* k in [cdend, clen) */
701076ba34aSJunchao Zhang           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
702076ba34aSJunchao Zhang         }
703076ba34aSJunchao Zhang       });
704076ba34aSJunchao Zhang     });
705076ba34aSJunchao Zhang 
706076ba34aSJunchao Zhang     akok->a_dual.modify_device();
707076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
708076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
709076ba34aSJunchao Zhang     Mat                         Cd,Co;
710076ba34aSJunchao Zhang     const MatColIdxKokkosView&  Cj = csrmat.graph.entries;
711076ba34aSJunchao Zhang     MatRowMapKokkosDualView     Cdi_dual("i",m+1),Coi_dual("i",m+1);
712076ba34aSJunchao Zhang     MatRowMapKokkosView         Cdi = Cdi_dual.view_device(),Coi = Coi_dual.view_device();
713076ba34aSJunchao Zhang     PetscInt                    cstart,cend;
714076ba34aSJunchao Zhang 
715076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
716076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
717076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
718076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
719076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
720076ba34aSJunchao Zhang      */
721076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart",m);
722076ba34aSJunchao Zhang     ierr    = PetscLayoutGetRange(C->cmap,&cstart,&cend);CHKERRQ(ierr); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
723076ba34aSJunchao Zhang 
724076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
725076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
726076ba34aSJunchao Zhang      */
727076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, 1),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
728076ba34aSJunchao Zhang       Kokkos::single(Kokkos::PerTeam(t), [=] () { /* Only one thread works in a team */
729076ba34aSJunchao Zhang         PetscInt i = t.league_rank(); /* row i */
730076ba34aSJunchao Zhang         PetscInt j,first,count,step;
731076ba34aSJunchao Zhang 
732076ba34aSJunchao Zhang         if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
733076ba34aSJunchao Zhang           Cdi(0) = 0;
734076ba34aSJunchao Zhang           Coi(0) = 0;
735076ba34aSJunchao Zhang         }
736076ba34aSJunchao Zhang 
737076ba34aSJunchao Zhang         /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
738076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
739076ba34aSJunchao Zhang         */
740076ba34aSJunchao Zhang         count = Ci(i+1)-Ci(i);
741076ba34aSJunchao Zhang         first = Ci(i);
742076ba34aSJunchao Zhang         while (count > 0) {
743076ba34aSJunchao Zhang           j    = first;
744076ba34aSJunchao Zhang           step = count / 2;
745076ba34aSJunchao Zhang           j   += step;
746076ba34aSJunchao Zhang           if (Cj(j) < cstart) {
747076ba34aSJunchao Zhang             first  = ++j;
748076ba34aSJunchao Zhang             count -= step + 1;
749076ba34aSJunchao Zhang           } else count = step;
750076ba34aSJunchao Zhang         }
751076ba34aSJunchao Zhang         Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
752076ba34aSJunchao Zhang 
753076ba34aSJunchao Zhang         /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
754076ba34aSJunchao Zhang         count = Ci(i+1) - first;
755076ba34aSJunchao Zhang         while (count > 0) {
756076ba34aSJunchao Zhang           j    = first;
757076ba34aSJunchao Zhang           step = count / 2;
758076ba34aSJunchao Zhang           j   += step;
759076ba34aSJunchao Zhang           if (Cj(j) < cend) {
760076ba34aSJunchao Zhang             first  = ++j;
761076ba34aSJunchao Zhang             count -= step + 1;
762076ba34aSJunchao Zhang           } else count = step;
763076ba34aSJunchao Zhang         }
764076ba34aSJunchao Zhang         Cdi(i+1) = first - (Ci(i)+Cdstart(i)); /* 'first' is the while-loop's output */
765076ba34aSJunchao Zhang         Coi(i+1) = (Ci(i+1)-Ci(i)) - Cdi(i+1); /* Co's row len = C's row len - Cd's row len */
766076ba34aSJunchao Zhang       });
767076ba34aSJunchao Zhang     });
768076ba34aSJunchao Zhang 
769076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
770076ba34aSJunchao Zhang     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
771076ba34aSJunchao Zhang       update += Cdi(i);
772076ba34aSJunchao Zhang       if (final) Cdi(i) = update;
773076ba34aSJunchao Zhang     });
774076ba34aSJunchao Zhang     Kokkos::parallel_scan(m+1,KOKKOS_LAMBDA(const PetscInt i,PetscInt& update,const bool final) {
775076ba34aSJunchao Zhang       update += Coi(i);
776076ba34aSJunchao Zhang       if (final) Coi(i) = update;
777076ba34aSJunchao Zhang     });
778076ba34aSJunchao Zhang 
779076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
780076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
781076ba34aSJunchao Zhang     */
782076ba34aSJunchao Zhang     Cdi_dual.modify_device();
783076ba34aSJunchao Zhang     Coi_dual.modify_device();
784076ba34aSJunchao Zhang     Cdi_dual.sync_host();
785076ba34aSJunchao Zhang     Coi_dual.sync_host();
786076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
787076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
788076ba34aSJunchao Zhang 
789076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
790076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j",Cd_nnz),Coj_dual("j",Co_nnz);
791076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a",Cd_nnz),Coa_dual("a",Co_nnz);
792076ba34aSJunchao Zhang 
793076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
794076ba34aSJunchao Zhang     MatColIdxKokkosView     Cdj = Cdj_dual.view_device(),Coj = Coj_dual.view_device();
795076ba34aSJunchao Zhang     MatScalarKokkosView     Cda = Cda_dual.view_device(),Coa = Coa_dual.view_device();
796076ba34aSJunchao Zhang 
797076ba34aSJunchao Zhang     Kokkos::parallel_for(Kokkos::TeamPolicy<>(m, Kokkos::AUTO()),KOKKOS_LAMBDA(const KokkosTeamMemberType& t) {
798076ba34aSJunchao Zhang       PetscInt i       = t.league_rank(); /* row i */
799076ba34aSJunchao Zhang       PetscInt clen    = Ci(i+1) - Ci(i); /* len of row i of C */
800076ba34aSJunchao Zhang       PetscInt cdlen   = Cdi(i+1) - Cdi(i); /* len of row i of Cd */
801076ba34aSJunchao Zhang       PetscInt cdstart = Cdstart(i); /* [start, end) of row i of Cd in C */
802076ba34aSJunchao Zhang       PetscInt cdend   = cdstart + cdlen;
803076ba34aSJunchao Zhang       /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
804076ba34aSJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
805076ba34aSJunchao Zhang         if (k < cdstart) { /* k in [0, cdstart) */
806076ba34aSJunchao Zhang           Coa(Coi(i)+k) = Ca(Ci(i)+k);
807076ba34aSJunchao Zhang           Coj(Coi(i)+k) = Cj(Ci(i)+k);
808076ba34aSJunchao Zhang         } else if (k < cdend) { /* k in [cdstart, cdend) */
809076ba34aSJunchao Zhang           Cda(Cdi(i)+(k-cdstart)) = Ca(Ci(i)+k);
810076ba34aSJunchao Zhang           Cdj(Cdi(i)+(k-cdstart)) = Cj(Ci(i)+k) - cstart; /* Use local col ids in Cdj */
811076ba34aSJunchao Zhang         } else { /* k in [cdend, clen) */
812076ba34aSJunchao Zhang           Coa(Coi(i)+k-cdlen) = Ca(Ci(i)+k);
813076ba34aSJunchao Zhang           Coj(Coi(i)+k-cdlen) = Cj(Ci(i)+k);
814076ba34aSJunchao Zhang         }
815076ba34aSJunchao Zhang       });
816076ba34aSJunchao Zhang     });
817076ba34aSJunchao Zhang 
818076ba34aSJunchao Zhang     Cdj_dual.modify_device();
819076ba34aSJunchao Zhang     Cda_dual.modify_device();
820076ba34aSJunchao Zhang     Coj_dual.modify_device();
821076ba34aSJunchao Zhang     Coa_dual.modify_device();
822076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
823076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m,n,Cd_nnz,Cdi_dual,Cdj_dual,Cda_dual);
824076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m,N,Co_nnz,Coi_dual,Coj_dual,Coa_dual);
825076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cdkok,&Cd);CHKERRQ(ierr);
826076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,cokok,&Co);CHKERRQ(ierr);
827076ba34aSJunchao Zhang     ierr = MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C,Cd,Co);CHKERRQ(ierr); /* Coj will be converted to local ids within */
828076ba34aSJunchao Zhang   }
829076ba34aSJunchao Zhang   PetscFunctionReturn(0);
830076ba34aSJunchao Zhang }
831076ba34aSJunchao Zhang 
/* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.

  It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.

  Input Parameters:
.  C        - the SEQAIJKOKKOS matrix, whose col ids are to be compacted

  Output Parameters:
+  C        - the updated matrix, with compacted col ids and an updated col size
-  l2g      - the local to global mapping of C's cols, in a Kokkos view

  Notes: the input matrix's col ids and col size will be changed.
*/
849076ba34aSJunchao Zhang static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C,MatColIdxKokkosView& l2g)
850076ba34aSJunchao Zhang {
851076ba34aSJunchao Zhang   PetscErrorCode         ierr;
852076ba34aSJunchao Zhang   Mat_SeqAIJKokkos       *ckok;
853076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
854076ba34aSJunchao Zhang   const PetscInt         *garray;
855076ba34aSJunchao Zhang   PetscInt               sz;
856076ba34aSJunchao Zhang 
857076ba34aSJunchao Zhang   PetscFunctionBegin;
858076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
859076ba34aSJunchao Zhang   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJ(C,&l2gmap);CHKERRQ(ierr);
860076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos*>(C->spptr);
861076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
862076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
863076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
864076ba34aSJunchao Zhang 
865076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
866076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingGetIndices(l2gmap,&garray);CHKERRQ(ierr);
867076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingGetSize(l2gmap,&sz);CHKERRQ(ierr);
868546078acSJacob Faibussowitsch   if (C->cmap->n != sz) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_PLIB,"matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n,sz);
869076ba34aSJunchao Zhang 
870076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray,sz);
871076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g",sz);
872076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g,tmp);
873076ba34aSJunchao Zhang 
874076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingRestoreIndices(l2gmap,&garray);CHKERRQ(ierr);
875076ba34aSJunchao Zhang   ierr = ISLocalToGlobalMappingDestroy(&l2gmap);CHKERRQ(ierr);
876076ba34aSJunchao Zhang   PetscFunctionReturn(0);
877076ba34aSJunchao Zhang }
878076ba34aSJunchao Zhang 
879076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
880076ba34aSJunchao Zhang 
881076ba34aSJunchao Zhang   Input Parameters:
882076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
883076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
884076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
885076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
886076ba34aSJunchao Zhang 
887076ba34aSJunchao Zhang   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
888076ba34aSJunchao Zhang */
889076ba34aSJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product,Mat A,Mat B,MatMatStruct_AB *mm)
890076ba34aSJunchao Zhang {
891076ba34aSJunchao Zhang   PetscErrorCode              ierr;
892076ba34aSJunchao Zhang   Mat_MPIAIJ                  *a = static_cast<Mat_MPIAIJ*>(A->data);
893076ba34aSJunchao Zhang   Mat                         Ad = a->A,Ao = a->B; /* diag and offdiag of A */
894076ba34aSJunchao Zhang   IS                          glob = NULL;
895076ba34aSJunchao Zhang   const PetscInt              *garray;
896076ba34aSJunchao Zhang   PetscInt                    N = B->cmap->N,sz;
897076ba34aSJunchao Zhang   ConstMatColIdxKokkosView    l2g1; /* two temp maps mapping local col ids to global ones */
898076ba34aSJunchao Zhang   MatColIdxKokkosView         l2g2;
899076ba34aSJunchao Zhang   Mat                         C1,C2; /* intermediate matrices */
900076ba34aSJunchao Zhang 
901076ba34aSJunchao Zhang   PetscFunctionBegin;
902076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
903076ba34aSJunchao Zhang   ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&mm->B_local);CHKERRQ(ierr);
904076ba34aSJunchao Zhang   ierr = MatProductCreate(Ad,mm->B_local,NULL,&C1);CHKERRQ(ierr);
905076ba34aSJunchao Zhang   ierr = MatProductSetType(C1,MATPRODUCT_AB);CHKERRQ(ierr);
906076ba34aSJunchao Zhang   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
907076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
908076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
909076ba34aSJunchao Zhang   if (!C1->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
910076ba34aSJunchao Zhang   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
911076ba34aSJunchao Zhang 
912076ba34aSJunchao Zhang   ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
913076ba34aSJunchao Zhang   ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
914076ba34aSJunchao Zhang   const auto& tmp  = ConstMatColIdxKokkosViewHost(garray,sz); /* wrap garray as a view */
915076ba34aSJunchao Zhang   l2g1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
916076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g1,mm->C1_global);
917076ba34aSJunchao Zhang 
918076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
919076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosBcast(mm->B_local,MAT_INITIAL_MATRIX,N,l2g1,a->Mvctx,mm->sf,
920076ba34aSJunchao Zhang                               mm->abuf,mm->rows,mm->rowoffset,mm->B_other);CHKERRQ(ierr);
921076ba34aSJunchao Zhang 
922076ba34aSJunchao Zhang   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
923076ba34aSJunchao Zhang   ierr = MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other,l2g2);CHKERRQ(ierr);
924076ba34aSJunchao Zhang   ierr = MatProductCreate(Ao,mm->B_other,NULL,&C2);CHKERRQ(ierr);
925076ba34aSJunchao Zhang   ierr = MatProductSetType(C2,MATPRODUCT_AB);CHKERRQ(ierr);
926076ba34aSJunchao Zhang   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
927076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
928076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
929076ba34aSJunchao Zhang   if (!C2->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
930076ba34aSJunchao Zhang   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
931076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2,N,l2g2,mm->C2_global);
932076ba34aSJunchao Zhang 
933076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
934076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
935076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
936076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
937076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
938076ba34aSJunchao Zhang 
939076ba34aSJunchao Zhang   mm->C1 = C1;
940076ba34aSJunchao Zhang   mm->C2 = C2;
941076ba34aSJunchao Zhang   ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
942076ba34aSJunchao Zhang   ierr = ISDestroy(&glob);CHKERRQ(ierr);
943076ba34aSJunchao Zhang   PetscFunctionReturn(0);
944076ba34aSJunchao Zhang }
945076ba34aSJunchao Zhang 
946076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
947076ba34aSJunchao Zhang 
948076ba34aSJunchao Zhang   Input Parameters:
949076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
950076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
951076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
952076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
953076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
954076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
955076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
956076ba34aSJunchao Zhang 
957076ba34aSJunchao Zhang   Notes: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
958076ba34aSJunchao Zhang */
959076ba34aSJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product,Mat A,Mat B,PetscBool localB,PetscInt N,const ConstMatColIdxKokkosView& l2g,MatMatStruct_AtB *mm)
960076ba34aSJunchao Zhang {
961076ba34aSJunchao Zhang   PetscErrorCode         ierr;
962076ba34aSJunchao Zhang   Mat_MPIAIJ             *a = static_cast<Mat_MPIAIJ*>(A->data);
963076ba34aSJunchao Zhang   Mat                    Ad = a->A,Ao = a->B; /* diag and offdiag of A */
964076ba34aSJunchao Zhang   Mat                    C1,C2; /* intermediate matrices */
965076ba34aSJunchao Zhang 
966076ba34aSJunchao Zhang   PetscFunctionBegin;
967076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
968076ba34aSJunchao Zhang   ierr = MatProductCreate(Ad,B,NULL,&C1);CHKERRQ(ierr);
969076ba34aSJunchao Zhang   ierr = MatProductSetType(C1,MATPRODUCT_AtB);CHKERRQ(ierr);
970076ba34aSJunchao Zhang   ierr = MatProductSetFill(C1,product->fill);CHKERRQ(ierr);
971076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
972076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C1);CHKERRQ(ierr);
973076ba34aSJunchao Zhang   if (!C1->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C1->product->type]);
974076ba34aSJunchao Zhang   ierr = (*C1->ops->productsymbolic)(C1);CHKERRQ(ierr);
975076ba34aSJunchao Zhang 
976076ba34aSJunchao Zhang   if (localB) {ierr = MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1,N,l2g,mm->C1_global);}
977076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos*>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
978076ba34aSJunchao Zhang 
979076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
980076ba34aSJunchao Zhang   ierr = MatProductCreate(Ao,B,NULL,&C2);CHKERRQ(ierr);
981076ba34aSJunchao Zhang   ierr = MatProductSetType(C2,MATPRODUCT_AtB);CHKERRQ(ierr);
982076ba34aSJunchao Zhang   ierr = MatProductSetFill(C2,product->fill);CHKERRQ(ierr);
983076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
984076ba34aSJunchao Zhang   ierr = MatProductSetFromOptions(C2);CHKERRQ(ierr);
985076ba34aSJunchao Zhang   if (!C2->ops->productsymbolic) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing symbolic op for %s",MatProductTypes[C2->product->type]);
986076ba34aSJunchao Zhang   ierr = (*C2->ops->productsymbolic)(C2);CHKERRQ(ierr);
987076ba34aSJunchao Zhang 
988076ba34aSJunchao Zhang   ierr = MatSeqAIJKokkosReduce(C2,MAT_INITIAL_MATRIX,localB,N,l2g,a->Mvctx,mm->sf,mm->abuf,
989076ba34aSJunchao Zhang                                mm->srcrowoffset,mm->dstrowoffset,mm->C2_global);CHKERRQ(ierr);
990076ba34aSJunchao Zhang 
991076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
992076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh,mm->C1_global,mm->C2_global,mm->C_global);
993076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
994076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh,(MatScalarType)1.0,mm->C1_global,(MatScalarType)1.0,mm->C2_global,mm->C_global);
995076ba34aSJunchao Zhang   mm->C1 = C1;
996076ba34aSJunchao Zhang   mm->C2 = C2;
997076ba34aSJunchao Zhang   PetscFunctionReturn(0);
998076ba34aSJunchao Zhang }
999076ba34aSJunchao Zhang 
1000076ba34aSJunchao Zhang PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1001076ba34aSJunchao Zhang {
1002076ba34aSJunchao Zhang   PetscErrorCode                ierr;
1003076ba34aSJunchao Zhang   Mat_Product                   *product = C->product;
1004076ba34aSJunchao Zhang   MatProductType                ptype;
1005076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos   *mmdata;
1006076ba34aSJunchao Zhang   MatMatStruct                  *mm = NULL;
1007076ba34aSJunchao Zhang   MatMatStruct_AB               *ab;
1008076ba34aSJunchao Zhang   MatMatStruct_AtB              *atb;
1009076ba34aSJunchao Zhang   Mat                           A,B,Ad,Ao,Bd,Bo;
1010076ba34aSJunchao Zhang   const MatScalarType           one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1011076ba34aSJunchao Zhang 
1012076ba34aSJunchao Zhang   PetscFunctionBegin;
1013076ba34aSJunchao Zhang   MatCheckProduct(C,1);
1014076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos*>(product->data);
1015076ba34aSJunchao Zhang   ptype  = product->type;
1016076ba34aSJunchao Zhang   A      = product->A;
1017076ba34aSJunchao Zhang   B      = product->B;
1018076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ*>(A->data)->A;
1019076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ*>(A->data)->B;
1020076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ*>(B->data)->A;
1021076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ*>(B->data)->B;
1022076ba34aSJunchao Zhang 
1023076ba34aSJunchao Zhang   if (mmdata->reusesym) { /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1024076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1025076ba34aSJunchao Zhang     ab  = mmdata->mmAB;
1026076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
1027076ba34aSJunchao Zhang     if (ab) {
1028076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1029076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1030076ba34aSJunchao Zhang     }
1031076ba34aSJunchao Zhang     if (atb) {
1032076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1033076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos*>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1034076ba34aSJunchao Zhang     }
1035076ba34aSJunchao Zhang     PetscFunctionReturn(0);
1036076ba34aSJunchao Zhang   }
1037076ba34aSJunchao Zhang 
1038076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1039076ba34aSJunchao Zhang     ab   = mmdata->mmAB;
1040076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
1041076ba34aSJunchao Zhang     if (!ab->C1->ops->productnumeric || !ab->C2->ops->productnumeric) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AB");
1042076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1043076ba34aSJunchao Zhang     if (ab->C1->product->B != ab->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1044076ba34aSJunchao Zhang     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1045076ba34aSJunchao Zhang     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1046076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1047076ba34aSJunchao Zhang                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1048076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
1049076ba34aSJunchao Zhang     if (ab->C2->product->B != ab->B_other) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1050076ba34aSJunchao Zhang     if (ab->C1->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1051076ba34aSJunchao Zhang     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr);
1052076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1053076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1054076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(ab);
1055076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1056076ba34aSJunchao Zhang     atb  = mmdata->mmAtB;
1057076ba34aSJunchao Zhang     if (!atb->C1->ops->productnumeric || !atb->C2->ops->productnumeric) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing numeric op for MATPRODUCT_AtB");
1058076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
1059076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&atb->B_local);CHKERRQ(ierr);
1060076ba34aSJunchao Zhang     if (atb->C1->product->B != atb->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1061076ba34aSJunchao Zhang     if (atb->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1062076ba34aSJunchao Zhang     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1063076ba34aSJunchao Zhang 
1064076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
1065076ba34aSJunchao Zhang     if (atb->C2->product->B != atb->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1066076ba34aSJunchao Zhang     if (atb->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1067076ba34aSJunchao Zhang     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1068076ba34aSJunchao Zhang     /* Form C2_global */
1069076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_TRUE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1070076ba34aSJunchao Zhang                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1071076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1072076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1073076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(atb);
1074076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1075076ba34aSJunchao Zhang     ab   = mmdata->mmAB;
1076076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_REUSE_MATRIX,NULL/*glob*/,&ab->B_local);CHKERRQ(ierr);
1077076ba34aSJunchao Zhang 
1078076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
1079076ba34aSJunchao Zhang     if (ab->C1->product->B != ab->B_local) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1080076ba34aSJunchao Zhang     if (ab->C1->product->A != Ad) {ierr = MatProductReplaceMats(Ad,NULL,NULL,ab->C1);CHKERRQ(ierr);}
1081076ba34aSJunchao Zhang     ierr = (*ab->C1->ops->productnumeric)(ab->C1);CHKERRQ(ierr);
1082076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosBcast(ab->B_local,MAT_REUSE_MATRIX,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,ab->sf,
1083076ba34aSJunchao Zhang                                 ab->abuf,ab->rows,ab->rowoffset,ab->B_other);CHKERRQ(ierr);
1084076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
1085076ba34aSJunchao Zhang     if (ab->C2->product->A != Ao) {ierr = MatProductReplaceMats(Ao,NULL,NULL,ab->C2);CHKERRQ(ierr);}
1086076ba34aSJunchao Zhang     ierr = (*ab->C2->ops->productnumeric)(ab->C2);CHKERRQ(ierr); /* C2 = Ao * B_other */
1087076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh,one,ab->C1_global,one,ab->C2_global,ab->C_global);
1088076ba34aSJunchao Zhang 
1089076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1090076ba34aSJunchao Zhang     atb  = mmdata->mmAtB;
1091076ba34aSJunchao Zhang     if (atb->C1->product->B != ab->C_petsc) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1092076ba34aSJunchao Zhang     if (atb->C1->product->A != Bd) {ierr = MatProductReplaceMats(Bd,NULL,NULL,atb->C1);CHKERRQ(ierr);}
1093076ba34aSJunchao Zhang     ierr = (*atb->C1->ops->productnumeric)(atb->C1);CHKERRQ(ierr);
1094076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
1095076ba34aSJunchao Zhang     if (atb->C2->product->A != Bo) {ierr = MatProductReplaceMats(Bo,NULL,NULL,atb->C2);CHKERRQ(ierr);}
1096076ba34aSJunchao Zhang     ierr = (*atb->C2->ops->productnumeric)(atb->C2);CHKERRQ(ierr);
1097076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosReduce(atb->C2,MAT_REUSE_MATRIX,PETSC_FALSE,0/*N*/,MatColIdxKokkosView()/*l2g*/,NULL/*ownerSF*/,atb->sf,
1098076ba34aSJunchao Zhang                                  atb->abuf,atb->srcrowoffset,atb->dstrowoffset,atb->C2_global);CHKERRQ(ierr);
1099076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh,one,atb->C1_global,one,atb->C2_global,atb->C_global);
1100076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct*>(atb);
1101076ba34aSJunchao Zhang   }
1102076ba34aSJunchao Zhang   /* Split C_global to form C */
1103076ba34aSJunchao Zhang   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_REUSE_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1104076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1105076ba34aSJunchao Zhang }
1106076ba34aSJunchao Zhang 
1107076ba34aSJunchao Zhang PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1108076ba34aSJunchao Zhang {
1109076ba34aSJunchao Zhang   PetscErrorCode              ierr;
1110076ba34aSJunchao Zhang   Mat                         A,B;
1111076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1112076ba34aSJunchao Zhang   MatProductType              ptype;
1113076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1114076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
1115076ba34aSJunchao Zhang   IS                          glob = NULL;
1116076ba34aSJunchao Zhang   const PetscInt              *garray;
1117076ba34aSJunchao Zhang   PetscInt                    m,n,M,N,sz;
1118076ba34aSJunchao Zhang   ConstMatColIdxKokkosView    l2g; /* map local col ids to global ones */
1119076ba34aSJunchao Zhang 
1120076ba34aSJunchao Zhang   PetscFunctionBegin;
1121076ba34aSJunchao Zhang   MatCheckProduct(C,1);
1122076ba34aSJunchao Zhang   if (product->data) SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Product data not empty");
1123076ba34aSJunchao Zhang   ptype = product->type;
1124076ba34aSJunchao Zhang   A     = product->A;
1125076ba34aSJunchao Zhang   B     = product->B;
1126076ba34aSJunchao Zhang 
1127076ba34aSJunchao Zhang   switch (ptype) {
1128076ba34aSJunchao Zhang     case MATPRODUCT_AB:   m = A->rmap->n; n = B->cmap->n; M = A->rmap->N; N = B->cmap->N; break;
1129076ba34aSJunchao Zhang     case MATPRODUCT_AtB:  m = A->cmap->n; n = B->cmap->n; M = A->cmap->N; N = B->cmap->N; break;
1130076ba34aSJunchao Zhang     case MATPRODUCT_PtAP: m = B->cmap->n; n = B->cmap->n; M = B->cmap->N; N = B->cmap->N; break; /* BtAB */
1131076ba34aSJunchao Zhang     default: SETERRQ1(PetscObjectComm((PetscObject)C),PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[ptype]);
1132076ba34aSJunchao Zhang   }
1133076ba34aSJunchao Zhang 
1134076ba34aSJunchao Zhang   ierr = MatSetSizes(C,m,n,M,N);CHKERRQ(ierr);
1135076ba34aSJunchao Zhang   ierr = PetscLayoutSetUp(C->rmap);CHKERRQ(ierr);
1136076ba34aSJunchao Zhang   ierr = PetscLayoutSetUp(C->cmap);CHKERRQ(ierr);
1137076ba34aSJunchao Zhang   ierr = MatSetType(C,((PetscObject)A)->type_name);CHKERRQ(ierr);
1138076ba34aSJunchao Zhang 
1139076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1140076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1141076ba34aSJunchao Zhang 
1142076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1143076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
1144076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,mmdata->mmAB);CHKERRQ(ierr);
1145076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(mmdata->mmAB);
1146076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1147076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1148076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
1149076ba34aSJunchao Zhang     ierr = MatMPIAIJGetLocalMatMerge(B,MAT_INITIAL_MATRIX,&glob,&atb->B_local);CHKERRQ(ierr);
1150076ba34aSJunchao Zhang     ierr = ISGetIndices(glob,&garray);CHKERRQ(ierr);
1151076ba34aSJunchao Zhang     ierr = ISGetSize(glob,&sz);CHKERRQ(ierr);
1152076ba34aSJunchao Zhang     l2g  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatColIdxKokkosViewHost(garray,sz));
1153076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,A,atb->B_local,PETSC_TRUE,N,l2g,atb);CHKERRQ(ierr);
1154076ba34aSJunchao Zhang     ierr = ISRestoreIndices(glob,&garray);CHKERRQ(ierr);
1155076ba34aSJunchao Zhang     ierr = ISDestroy(&glob);CHKERRQ(ierr);
1156076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(atb);
1157076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1158076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB(); /* tmp=A*B */
1159076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1160076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1161076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
1162076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AB(product,A,B,ab);CHKERRQ(ierr);
1163076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1164076ba34aSJunchao Zhang     ierr = MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF,tmp,&ab->C_petsc);CHKERRQ(ierr);
1165076ba34aSJunchao Zhang     ierr = MatProductSymbolic_MPIAIJKokkos_AtB(product,B,ab->C_petsc,PETSC_FALSE,N,l2g/*not used*/,atb);CHKERRQ(ierr);
1166076ba34aSJunchao Zhang     mm   = static_cast<MatMatStruct*>(atb);
1167076ba34aSJunchao Zhang   }
1168076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
1169076ba34aSJunchao Zhang   ierr = MatSetMPIAIJKokkosWithGlobalCSRMatrix(C,MAT_INITIAL_MATRIX,mm->C_global,mm->Cdstart);CHKERRQ(ierr);
1170076ba34aSJunchao Zhang   C->product->data        = mmdata;
1171076ba34aSJunchao Zhang   C->product->destroy     = MatProductDataDestroy_MPIAIJKokkos;
1172076ba34aSJunchao Zhang   C->ops->productnumeric  = MatProductNumeric_MPIAIJKokkos;
1173076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1174076ba34aSJunchao Zhang }
1175076ba34aSJunchao Zhang 
1176076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1177076ba34aSJunchao Zhang {
1178076ba34aSJunchao Zhang   PetscErrorCode ierr;
1179076ba34aSJunchao Zhang   Mat_Product    *product = mat->product;
1180076ba34aSJunchao Zhang   PetscBool      match = PETSC_FALSE;
1181076ba34aSJunchao Zhang   PetscBool      usecpu = PETSC_FALSE;
1182076ba34aSJunchao Zhang 
1183076ba34aSJunchao Zhang   PetscFunctionBegin;
1184076ba34aSJunchao Zhang   MatCheckProduct(mat,1);
1185076ba34aSJunchao Zhang   if (!product->A->boundtocpu && !product->B->boundtocpu) {
1186076ba34aSJunchao Zhang     ierr = PetscObjectTypeCompare((PetscObject)product->B,((PetscObject)product->A)->type_name,&match);CHKERRQ(ierr);
1187076ba34aSJunchao Zhang   }
1188076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1189076ba34aSJunchao Zhang     switch (product->type) {
1190076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1191076ba34aSJunchao Zhang       if (product->api_user) {
1192076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatMatMult","Mat");CHKERRQ(ierr);
1193076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matmatmult_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1194076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1195076ba34aSJunchao Zhang       } else {
1196076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AB","Mat");CHKERRQ(ierr);
1197*3e662e0bSHong Zhang         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1198076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1199076ba34aSJunchao Zhang       }
1200076ba34aSJunchao Zhang       break;
1201076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1202076ba34aSJunchao Zhang       if (product->api_user) {
1203076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatTransposeMatMult","Mat");CHKERRQ(ierr);
1204076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-mattransposematmult_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1205076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1206076ba34aSJunchao Zhang       } else {
1207076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_AtB","Mat");CHKERRQ(ierr);
1208*3e662e0bSHong Zhang         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatTransposeMatMult",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1209076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1210076ba34aSJunchao Zhang       }
1211076ba34aSJunchao Zhang       break;
1212076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1213076ba34aSJunchao Zhang       if (product->api_user) {
1214076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatPtAP","Mat");CHKERRQ(ierr);
1215076ba34aSJunchao Zhang         ierr = PetscOptionsBool("-matptap_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1216076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1217076ba34aSJunchao Zhang       } else {
1218076ba34aSJunchao Zhang         ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)mat),((PetscObject)mat)->prefix,"MatProduct_PtAP","Mat");CHKERRQ(ierr);
1219*3e662e0bSHong Zhang         ierr = PetscOptionsBool("-mat_product_algorithm_backend_cpu","Use CPU code","MatPtAP",usecpu,&usecpu,NULL);CHKERRQ(ierr);
1220076ba34aSJunchao Zhang         ierr = PetscOptionsEnd();CHKERRQ(ierr);
1221076ba34aSJunchao Zhang       }
1222076ba34aSJunchao Zhang       break;
1223076ba34aSJunchao Zhang     default:
1224076ba34aSJunchao Zhang       break;
1225076ba34aSJunchao Zhang     }
1226076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1227076ba34aSJunchao Zhang   }
1228076ba34aSJunchao Zhang   if (match) {
1229076ba34aSJunchao Zhang     switch (product->type) {
1230076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1231076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1232076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1233076ba34aSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1234076ba34aSJunchao Zhang       break;
1235076ba34aSJunchao Zhang     default:
1236076ba34aSJunchao Zhang       break;
1237076ba34aSJunchao Zhang     }
1238076ba34aSJunchao Zhang   }
1239076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
1240076ba34aSJunchao Zhang   if (!mat->ops->productsymbolic) {
1241076ba34aSJunchao Zhang     ierr = MatProductSetFromOptions_MPIAIJ(mat);CHKERRQ(ierr);
1242076ba34aSJunchao Zhang   }
1243076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1244076ba34aSJunchao Zhang }
1245076ba34aSJunchao Zhang 
124642550becSJunchao Zhang /* std::upper_bound(): Given a sorted array, return index of the first element in range [first,last) whose value
124742550becSJunchao Zhang    is greater than value, or last if there is no such element.
124842550becSJunchao Zhang */
124942550becSJunchao Zhang PETSC_STATIC_INLINE PetscErrorCode PetscSortedIntUpperBound(PetscInt *array,PetscInt first,PetscInt last,PetscInt value,PetscInt *upper)
125042550becSJunchao Zhang {
125142550becSJunchao Zhang   PetscInt  it,step,count = last - first;
125242550becSJunchao Zhang 
125342550becSJunchao Zhang   PetscFunctionBegin;
125442550becSJunchao Zhang   while (count > 0) {
125542550becSJunchao Zhang     it   = first;
125642550becSJunchao Zhang     step = count / 2;
125742550becSJunchao Zhang     it  += step;
125842550becSJunchao Zhang     if (!(value < array[it])) {
125942550becSJunchao Zhang       first  = ++it;
126042550becSJunchao Zhang       count -= step + 1;
126142550becSJunchao Zhang     } else count = step;
126242550becSJunchao Zhang   }
126342550becSJunchao Zhang   *upper = first;
126442550becSJunchao Zhang   PetscFunctionReturn(0);
126542550becSJunchao Zhang }
126642550becSJunchao Zhang 
126742550becSJunchao Zhang /* Merge two sets of sorted nonzero entries and return a CSR for the merged (sequential) matrix
126842550becSJunchao Zhang 
126942550becSJunchao Zhang   Input Parameters:
127042550becSJunchao Zhang 
127142550becSJunchao Zhang     j1,rowBegin1,rowEnd1,perm1,jmap1: describe the first set of nonzeros (Set1)
127242550becSJunchao Zhang     j2,rowBegin2,rowEnd2,perm2,jmap2: describe the second set of nonzeros (Set2)
127342550becSJunchao Zhang 
127442550becSJunchao Zhang     mat: both sets' entries are on m rows, where m is the number of local rows of the matrix mat
127542550becSJunchao Zhang 
127642550becSJunchao Zhang     For Set1, j1[] contains column indices of the nonzeros.
127742550becSJunchao Zhang     For the k-th row (0<=k<m), [rowBegin1[k],rowEnd1[k]) index into j1[] and point to the begin/end nonzero in row k
127842550becSJunchao Zhang     respectively (note rowEnd1[k] is not necessarily equal to rwoBegin1[k+1]). Indices in this range of j1[] are sorted,
127942550becSJunchao Zhang     but might have repeats. jmap1[t+1] - jmap1[t] is the number of repeats for the t-th unique nonzero in Set1.
128042550becSJunchao Zhang 
128142550becSJunchao Zhang     Similar for Set2.
128242550becSJunchao Zhang 
128342550becSJunchao Zhang     This routine merges the two sets of nonzeros row by row and removes repeats.
128442550becSJunchao Zhang 
128542550becSJunchao Zhang   Output Parameters: (memories are allocated by the caller)
128642550becSJunchao Zhang 
128742550becSJunchao Zhang     i[],j[]: the CSR of the merged matrix, which has m rows.
128842550becSJunchao Zhang     imap1[]: the k-th unique nonzero in Set1 (k=0,1,...) corresponds to imap1[k]-th unique nonzero in the merged matrix.
128942550becSJunchao Zhang     imap2[]: similar to imap1[], but for Set2.
129042550becSJunchao Zhang     Note we order nonzeros row-by-row and from left to right.
129142550becSJunchao Zhang */
129242550becSJunchao Zhang static PetscErrorCode MatMergeEntries_Internal(Mat mat,const PetscInt *j1,const PetscInt *j2,const PetscInt *rowBegin1,const PetscInt *rowEnd1,
129342550becSJunchao Zhang   const PetscInt *rowBegin2,const PetscInt *rowEnd2,const MatRowMapKokkosViewHost& jmap1_h,const MatRowMapKokkosViewHost& jmap2_h,
129442550becSJunchao Zhang   MatRowMapKokkosViewHost& imap1_h,MatRowMapKokkosViewHost& imap2_h,PetscInt *i,PetscInt *j)
129542550becSJunchao Zhang {
129642550becSJunchao Zhang   PetscErrorCode ierr;
129742550becSJunchao Zhang   PetscInt       r,m,t,t1,t2,b1,e1,b2,e2;
129842550becSJunchao Zhang   PetscInt       *jmap1 = jmap1_h.data(),*jmap2 = jmap2_h.data(),*imap1 = imap1_h.data(),*imap2 = imap2_h.data();
129942550becSJunchao Zhang 
130042550becSJunchao Zhang   PetscFunctionBegin;
130142550becSJunchao Zhang   ierr = MatGetLocalSize(mat,&m,NULL);CHKERRQ(ierr);
130242550becSJunchao Zhang   t1   = t2 = t = 0; /* Count unique nonzeros of in Set1, Set1 and the merged respectively */
130342550becSJunchao Zhang   i[0] = 0;
130442550becSJunchao Zhang   for (r=0; r<m; r++) { /* Do row by row merging */
130542550becSJunchao Zhang     b1   = rowBegin1[r];
130642550becSJunchao Zhang     e1   = rowEnd1[r];
130742550becSJunchao Zhang     b2   = rowBegin2[r];
130842550becSJunchao Zhang     e2   = rowEnd2[r];
130942550becSJunchao Zhang     while (b1 < e1 && b2 < e2) {
131042550becSJunchao Zhang       if (j1[b1] == j2[b2]) { /* Same column index and hence same nonzero */
131142550becSJunchao Zhang         j[t]      = j1[b1];
131242550becSJunchao Zhang         imap1[t1] = t;
131342550becSJunchao Zhang         imap2[t2] = t;
131442550becSJunchao Zhang         b1       += jmap1[t1+1] - jmap1[t1]; /* Jump to next unique local nonzero */
131542550becSJunchao Zhang         b2       += jmap2[t2+1] - jmap2[t2]; /* Jump to next unique remote nonzero */
131642550becSJunchao Zhang         t1++; t2++; t++;
131742550becSJunchao Zhang       } else if (j1[b1] < j2[b2]) {
131842550becSJunchao Zhang         j[t]      = j1[b1];
131942550becSJunchao Zhang         imap1[t1] = t;
132042550becSJunchao Zhang         b1       += jmap1[t1+1] - jmap1[t1];
132142550becSJunchao Zhang         t1++; t++;
132242550becSJunchao Zhang       } else {
132342550becSJunchao Zhang         j[t]      = j2[b2];
132442550becSJunchao Zhang         imap2[t2] = t;
132542550becSJunchao Zhang         b2       += jmap2[t2+1] - jmap2[t2];
132642550becSJunchao Zhang         t2++; t++;
132742550becSJunchao Zhang       }
132842550becSJunchao Zhang     }
132942550becSJunchao Zhang     /* Merge the remaining in either j1[] or j2[] */
133042550becSJunchao Zhang     while (b1 < e1) {
133142550becSJunchao Zhang       j[t]      = j1[b1];
133242550becSJunchao Zhang       imap1[t1] = t;
133342550becSJunchao Zhang       b1       += jmap1[t1+1] - jmap1[t1];
133442550becSJunchao Zhang       t1++; t++;
133542550becSJunchao Zhang     }
133642550becSJunchao Zhang     while (b2 < e2) {
133742550becSJunchao Zhang       j[t]      = j2[b2];
133842550becSJunchao Zhang       imap2[t2] = t;
133942550becSJunchao Zhang       b2       += jmap2[t2+1] - jmap2[t2];
134042550becSJunchao Zhang       t2++; t++;
134142550becSJunchao Zhang     }
134242550becSJunchao Zhang     i[r+1] = t;
134342550becSJunchao Zhang   }
134442550becSJunchao Zhang   PetscFunctionReturn(0);
134542550becSJunchao Zhang }
134642550becSJunchao Zhang 
134742550becSJunchao Zhang /* Split a set/group of local entries into two subsets: those in the diagonal block and those in the off-diagonal block
134842550becSJunchao Zhang 
134942550becSJunchao Zhang   Input Parameters:
135042550becSJunchao Zhang     mat: an MPI matrix that provides row and column layout information for splitting. Let's assume its number of local rows is m.
135142550becSJunchao Zhang     n,i[],j[],perm[]: there are n input entries, belonging to m rows. Row/col indices of the entries are stored in i[] and j[]
135242550becSJunchao Zhang       respectively, along with a permutation array perm[]. Length of the i[],j[],perm[] arrays is n.
135342550becSJunchao Zhang 
135442550becSJunchao Zhang       i[] is already sorted, but within a row, j[] is not sorted and might have repeats.
135542550becSJunchao Zhang       i[] might contain negative indices at the beginning, which says the corresponding entries should be ignored in the splitting.
135642550becSJunchao Zhang 
135742550becSJunchao Zhang   Output Parameters:
135842550becSJunchao Zhang     j[],perm[]: the routine needs to sort j[] within each row along with perm[].
135942550becSJunchao Zhang     rowBegin[],rowMid[],rowEnd[]: of length m, and the memory is preallocated and zeroed by the caller.
136042550becSJunchao Zhang       They contain indices pointing to j[]. For 0<=r<m, [rowBegin[r],rowMid[r]) point to begin/end entries in row r of the diagonal block,
136142550becSJunchao Zhang       and [rowMid[r],rowEnd[r]) point to begin/end entries in row r of the off-diagonal block.
136242550becSJunchao Zhang 
136342550becSJunchao Zhang     Aperm_h,Ajmap_h: They are Kokkos views on host. This routine will resize and fill them with proper values. Let's say Aperm = Aperm_h.data(),
136442550becSJunchao Zhang       and Ajmap = Ajmap_h.data(). Aperm[] stores values from perm[] for entries in the diagonal block. Hence length of Aperm[] is the number
136542550becSJunchao Zhang       of entries in the diagonal block, though those entries might have repeats (i.e., same 'i,j' pair).
136642550becSJunchao Zhang       Ajmap[] stores the number of repeats of each unique nonzero in the diagonal block. More precisely, Ajmap[t+1] - Ajmap[t] is the number of
136742550becSJunchao Zhang       repeats for the t-th unique nonzero in the diagonal block. Ajmap[0] is always 0.
136842550becSJunchao Zhang       Length of Aperm_h is the number of nonzeros in the diagonal block.
136942550becSJunchao Zhang       Length of Ajmap_h is the number of unique nonzeros in the diagonal block + 1.
137042550becSJunchao Zhang 
137142550becSJunchao Zhang     Bperm_h and Bjmap_h are similar to Aperm_h and Ajmap_h, respectively, but for the off-diagonal block.
137242550becSJunchao Zhang */
137342550becSJunchao Zhang 
137442550becSJunchao Zhang static PetscErrorCode MatSplitEntries_Internal(Mat mat,PetscInt n,const PetscInt i[],
137542550becSJunchao Zhang   PetscInt j[],PetscInt perm[],PetscInt rowBegin[],PetscInt rowMid[],PetscInt rowEnd[],
137642550becSJunchao Zhang   MatRowMapKokkosViewHost& Aperm_h,MatRowMapKokkosViewHost& Ajmap_h,MatRowMapKokkosViewHost& Bperm_h,MatRowMapKokkosViewHost& Bjmap_h)
137742550becSJunchao Zhang {
137842550becSJunchao Zhang   PetscErrorCode    ierr;
137942550becSJunchao Zhang   PetscInt          cstart,cend,rstart,rend,mid;
138042550becSJunchao Zhang   PetscInt          Atot=0,Btot=0; /* Total number of nonzeros in the diagonal and off-diagonal blocks */
138142550becSJunchao Zhang   PetscInt          Annz=0,Bnnz=0; /* Number of unique nonzeros in the diagonal and off-diagonal blocks */
138242550becSJunchao Zhang   PetscInt          k,m,p,q,r,s,row,col;
138342550becSJunchao Zhang   PetscInt          *Aperm,*Bperm,*Ajmap,*Bjmap;
138442550becSJunchao Zhang 
138542550becSJunchao Zhang   PetscFunctionBegin;
138642550becSJunchao Zhang   ierr = PetscLayoutGetRange(mat->rmap,&rstart,&rend);CHKERRQ(ierr);
138742550becSJunchao Zhang   ierr = PetscLayoutGetRange(mat->cmap,&cstart,&cend);CHKERRQ(ierr);
138842550becSJunchao Zhang   m    = rend - rstart;
138942550becSJunchao Zhang 
139042550becSJunchao Zhang   for (k=0; k<n; k++) {if (i[k]>=0) break;} /* Skip negative rows */
139142550becSJunchao Zhang 
139242550becSJunchao Zhang   /* Process [k,n): sort and partition each local row into diag and offdiag portions,
139342550becSJunchao Zhang      fill rowBegin[], rowMid[], rowEnd[], and count Atot, Btot, Annz, Bnnz.
139442550becSJunchao Zhang   */
139542550becSJunchao Zhang   while (k<n) {
139642550becSJunchao Zhang     row = i[k];
139742550becSJunchao Zhang     /* Entries in [k,s) are in one row. Shift diagonal block col indices so that diag is ahead of offdiag after sorting the row */
139842550becSJunchao Zhang     for (s=k; s<n; s++) if (i[s] != row) break;
139942550becSJunchao Zhang     for (p=k; p<s; p++) {
140042550becSJunchao Zhang       if (j[p] >= cstart && j[p] < cend) j[p] -= PETSC_MAX_INT; /* Shift diag columns to range of [-PETSC_MAX_INT, -1]  */
140142550becSJunchao Zhang      #if defined(PETSC_USE_DEBUG)
140242550becSJunchao Zhang       else if (j[p] < 0 || j[p] > mat->cmap->N) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Column index %" PetscInt_FMT " is out of range",j[p]);
140342550becSJunchao Zhang      #endif
140442550becSJunchao Zhang     }
140542550becSJunchao Zhang     ierr = PetscSortIntWithArray(s-k,j+k,perm+k);CHKERRQ(ierr);
140642550becSJunchao Zhang     ierr = PetscSortedIntUpperBound(j,k,s,-1,&mid);CHKERRQ(ierr); /* Seperate [k,s) into [k,mid) for diag and [mid,s) for offdiag */
140742550becSJunchao Zhang     rowBegin[row-rstart] = k;
140842550becSJunchao Zhang     rowMid[row-rstart]   = mid;
140942550becSJunchao Zhang     rowEnd[row-rstart]   = s;
141042550becSJunchao Zhang 
141142550becSJunchao Zhang     /* Count nonzeros of this diag/offdiag row, which might have repeats */
141242550becSJunchao Zhang     Atot += mid - k;
141342550becSJunchao Zhang     Btot += s - mid;
141442550becSJunchao Zhang 
141542550becSJunchao Zhang     /* Count unique nonzeros of this diag/offdiag row */
141642550becSJunchao Zhang     for (p=k; p<mid;) {
141742550becSJunchao Zhang       col = j[p];
141842550becSJunchao Zhang       do {j[p] += PETSC_MAX_INT; p++;} while (p<mid && j[p] == col); /* Revert the modified diagonal indices */
141942550becSJunchao Zhang       Annz++;
142042550becSJunchao Zhang     }
142142550becSJunchao Zhang 
142242550becSJunchao Zhang     for (p=mid; p<s;) {
142342550becSJunchao Zhang       col = j[p];
142442550becSJunchao Zhang       do {p++;} while (p<s && j[p] == col);
142542550becSJunchao Zhang       Bnnz++;
142642550becSJunchao Zhang     }
142742550becSJunchao Zhang     k = s;
142842550becSJunchao Zhang   }
142942550becSJunchao Zhang 
143042550becSJunchao Zhang   /* Resize views according to Atot, Btot, Annz, Bnnz */
143142550becSJunchao Zhang   Kokkos::resize(Aperm_h,Atot);
143242550becSJunchao Zhang   Kokkos::resize(Ajmap_h,Annz+1);
143342550becSJunchao Zhang   Kokkos::resize(Bperm_h,Btot);
143442550becSJunchao Zhang   Kokkos::resize(Bjmap_h,Bnnz+1);
143542550becSJunchao Zhang   Aperm    = Aperm_h.data();
143642550becSJunchao Zhang   Bperm    = Bperm_h.data();
143742550becSJunchao Zhang   Ajmap    = Ajmap_h.data();
143842550becSJunchao Zhang   Bjmap    = Bjmap_h.data();
143942550becSJunchao Zhang   Ajmap[0] = 0;
144042550becSJunchao Zhang   Bjmap[0] = 0;
144142550becSJunchao Zhang 
144242550becSJunchao Zhang   /* Re-scan indices and copy diag/offdiag permuation indices to Aperm, Bperm and also fill Ajmap and Bjmap */
144342550becSJunchao Zhang   Atot = Btot = Annz = Bnnz = 0;
144442550becSJunchao Zhang   for (r=0; r<m; r++) {
144542550becSJunchao Zhang     k     = rowBegin[r];
144642550becSJunchao Zhang     mid   = rowMid[r];
144742550becSJunchao Zhang     s     = rowEnd[r];
144842550becSJunchao Zhang     ierr  = PetscArraycpy(Aperm+Atot,perm+k,  mid-k);CHKERRQ(ierr);
144942550becSJunchao Zhang     ierr  = PetscArraycpy(Bperm+Btot,perm+mid,s-mid);CHKERRQ(ierr);
145042550becSJunchao Zhang     Atot += mid - k;
145142550becSJunchao Zhang     Btot += s - mid;
145242550becSJunchao Zhang 
145342550becSJunchao Zhang     /* Scan column indices in this row and find out how many repeats each unique nonzero has */
145442550becSJunchao Zhang     for (p=k; p<mid;) {
145542550becSJunchao Zhang       col = j[p];
145642550becSJunchao Zhang       q   = p;
145742550becSJunchao Zhang       do {p++;} while (p<mid && j[p] == col);
145842550becSJunchao Zhang       Ajmap[Annz+1] = Ajmap[Annz] + (p - q);
145942550becSJunchao Zhang       Annz++;
146042550becSJunchao Zhang     }
146142550becSJunchao Zhang 
146242550becSJunchao Zhang     for (p=mid; p<s;) {
146342550becSJunchao Zhang       col = j[p];
146442550becSJunchao Zhang       q   = p;
146542550becSJunchao Zhang       do {p++;} while (p<s && j[p] == col);
146642550becSJunchao Zhang       Bjmap[Bnnz+1] = Bjmap[Bnnz] + (p - q);
146742550becSJunchao Zhang       Bnnz++;
146842550becSJunchao Zhang     }
146942550becSJunchao Zhang   }
147042550becSJunchao Zhang   PetscFunctionReturn(0);
147142550becSJunchao Zhang }
147242550becSJunchao Zhang 
147382a78a4eSJed Brown static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, const PetscInt coo_i[], const PetscInt coo_j[])
147442550becSJunchao Zhang {
147542550becSJunchao Zhang   PetscErrorCode            ierr;
147642550becSJunchao Zhang   MPI_Comm                  comm;
147742550becSJunchao Zhang   PetscMPIInt               rank,size;
147842550becSJunchao Zhang   PetscInt                  m,n,M,N,k,p,q,rstart,rend,cstart,cend,rem;
147942550becSJunchao Zhang 
148042550becSJunchao Zhang   PetscFunctionBegin;
148142550becSJunchao Zhang   ierr = PetscObjectGetComm((PetscObject)mat,&comm);CHKERRQ(ierr);
148242550becSJunchao Zhang   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
148342550becSJunchao Zhang   ierr = MPI_Comm_rank(comm,&rank);CHKERRMPI(ierr);
148442550becSJunchao Zhang 
148542550becSJunchao Zhang   ierr = PetscLayoutGetRange(mat->rmap,&rstart,&rend);CHKERRQ(ierr);
148642550becSJunchao Zhang   ierr = PetscLayoutGetRange(mat->cmap,&cstart,&cend);CHKERRQ(ierr);
148742550becSJunchao Zhang   ierr = MatGetLocalSize(mat,&m,&n);CHKERRQ(ierr);
148842550becSJunchao Zhang   ierr = MatGetSize(mat,&M,&N);CHKERRQ(ierr);
148942550becSJunchao Zhang 
149042550becSJunchao Zhang   /* ---------------------------------------------------------------------------*/
149142550becSJunchao Zhang   /* Sort (i,j) by row along with a permuation array, so that the to-be-ignored */
149242550becSJunchao Zhang   /* entries come first, then local rows, then remote rows.                     */
149342550becSJunchao Zhang   /* ---------------------------------------------------------------------------*/
149482a78a4eSJed Brown   PetscCount n1 = coo_n;
149582a78a4eSJed Brown   PetscInt *i1,*j1,*perm1; /* Copies of input COOs along with a permutation array */
149642550becSJunchao Zhang   ierr = PetscMalloc3(n1,&i1,n1,&j1,n1,&perm1);CHKERRQ(ierr);
149742550becSJunchao Zhang   ierr = PetscArraycpy(i1,coo_i,n1);CHKERRQ(ierr); /* Make a copy since we'll modify it */
149842550becSJunchao Zhang   ierr = PetscArraycpy(j1,coo_j,n1);CHKERRQ(ierr);
149942550becSJunchao Zhang   for (k=0; k<n1; k++) perm1[k] = k;
150042550becSJunchao Zhang 
150142550becSJunchao Zhang   /* Manipulate indices so that entries with negative row or col indices will have smallest
150242550becSJunchao Zhang      row indices, local entries will have greater but negative row indices, and remote entries
150342550becSJunchao Zhang      will have positive row indices.
150442550becSJunchao Zhang   */
150542550becSJunchao Zhang   for (k=0; k<n1; k++) {
150642550becSJunchao Zhang     if (i1[k] < 0 || j1[k] < 0) i1[k] = PETSC_MIN_INT; /* e.g., -2^31, minimal to move them ahead */
150742550becSJunchao Zhang     else if (i1[k] >= rstart && i1[k] < rend) i1[k] -= PETSC_MAX_INT; /* e.g., minus 2^31-1 to shift local rows to range of [-PETSC_MAX_INT, -1] */
150842550becSJunchao Zhang   }
150942550becSJunchao Zhang 
151042550becSJunchao Zhang   /* Sort by row; after that, [0,k) have ignored entires, [k,rem) have local rows and [rem,n1) have remote rows */
151142550becSJunchao Zhang   ierr = PetscSortIntWithArrayPair(n1,i1,j1,perm1);CHKERRQ(ierr);
151242550becSJunchao Zhang   for (k=0; k<n1; k++) {if (i1[k] > PETSC_MIN_INT) break;} /* Advance k to the first entry we need to take care of */
151342550becSJunchao Zhang   ierr = PetscSortedIntUpperBound(i1,k,n1,rend-1-PETSC_MAX_INT,&rem);CHKERRQ(ierr); /* rem is upper bound of the last local row */
151442550becSJunchao Zhang   for (; k<rem; k++) i1[k] += PETSC_MAX_INT; /* Revert row indices of local rows*/
151542550becSJunchao Zhang 
151642550becSJunchao Zhang   /* ---------------------------------------------------------------------------*/
151742550becSJunchao Zhang   /*           Split local rows into diag/offdiag portions                      */
151842550becSJunchao Zhang   /* ---------------------------------------------------------------------------*/
151942550becSJunchao Zhang   PetscInt                  *rowBegin1,*rowMid1,*rowEnd1;
152042550becSJunchao Zhang   MatRowMapKokkosViewHost   Ajmap1_h,Aperm1_h,Bjmap1_h,Bperm1_h,Cperm1_h("Cperm1_h",n1-rem);
152142550becSJunchao Zhang 
152242550becSJunchao Zhang   ierr = PetscCalloc3(m,&rowBegin1,m,&rowMid1,m,&rowEnd1);CHKERRQ(ierr);
152342550becSJunchao Zhang   ierr = MatSplitEntries_Internal(mat,rem,i1,j1,perm1,rowBegin1,rowMid1,rowEnd1,Aperm1_h,Ajmap1_h,Bperm1_h,Bjmap1_h);CHKERRQ(ierr);
152442550becSJunchao Zhang 
152542550becSJunchao Zhang   /* ---------------------------------------------------------------------------*/
152642550becSJunchao Zhang   /*           Send remote rows to their owner                                  */
152742550becSJunchao Zhang   /* ---------------------------------------------------------------------------*/
152842550becSJunchao Zhang   /* Find which rows should be sent to which remote ranks*/
152942550becSJunchao Zhang   PetscInt       nsend = 0;
153042550becSJunchao Zhang   PetscMPIInt    *sendto; /* Of length nsend, storing remote ranks */
153142550becSJunchao Zhang   PetscInt       *nentries; /* Of length nsend, storing number of entries to be sent to each remote rank */
153242550becSJunchao Zhang   const PetscInt *ranges;
153342550becSJunchao Zhang   PetscInt       maxNsend = size >= 128? 128 : size; /* Assume max 128 neighbors; realloc when needed */
153442550becSJunchao Zhang 
153542550becSJunchao Zhang   ierr = PetscLayoutGetRanges(mat->rmap,&ranges);CHKERRQ(ierr);
153642550becSJunchao Zhang   ierr = PetscMalloc2(maxNsend,&sendto,maxNsend,&nentries);CHKERRQ(ierr);
153742550becSJunchao Zhang   for (k=rem; k<n1;) {
153842550becSJunchao Zhang     PetscMPIInt  owner;
153942550becSJunchao Zhang     PetscInt     firstRow,lastRow;
154042550becSJunchao Zhang     /* Locate a row range */
154142550becSJunchao Zhang     firstRow = i1[k]; /* first row of this owner */
154242550becSJunchao Zhang     ierr     = PetscLayoutFindOwner(mat->rmap,firstRow,&owner);CHKERRQ(ierr);
154342550becSJunchao Zhang     lastRow  = ranges[owner+1]-1; /* last row of this owner */
154442550becSJunchao Zhang 
154542550becSJunchao Zhang     /* Find the first index 'p' in [k,n) with i[p] belonging to next owner */
154642550becSJunchao Zhang     ierr     = PetscSortedIntUpperBound(i1,k,n1,lastRow,&p);CHKERRQ(ierr);
154742550becSJunchao Zhang 
154842550becSJunchao Zhang     /* All entries in [k,p) belong to this remote owner */
154942550becSJunchao Zhang     if (nsend >= maxNsend) { /* Double the remote ranks arrays if not long enough */
155042550becSJunchao Zhang       PetscMPIInt *sendto2;
155142550becSJunchao Zhang       PetscInt    *nentries2;
155242550becSJunchao Zhang       PetscInt    maxNsend2 = (maxNsend <= size/2) ? maxNsend*2 : size;
155342550becSJunchao Zhang       ierr = PetscMalloc2(maxNsend2,&sendto2,maxNsend2,&nentries2);CHKERRQ(ierr);
155442550becSJunchao Zhang       ierr = PetscArraycpy(sendto2,sendto,maxNsend);CHKERRQ(ierr);
155542550becSJunchao Zhang       ierr = PetscArraycpy(nentries2,nentries2,maxNsend+1);CHKERRQ(ierr);
155642550becSJunchao Zhang       ierr = PetscFree2(sendto,nentries2);CHKERRQ(ierr);
155742550becSJunchao Zhang       sendto      = sendto2;
155842550becSJunchao Zhang       nentries    = nentries2;
155942550becSJunchao Zhang       maxNsend    = maxNsend2;
156042550becSJunchao Zhang     }
156142550becSJunchao Zhang     sendto[nsend]   = owner;
156242550becSJunchao Zhang     nentries[nsend] = p - k;
156342550becSJunchao Zhang     nsend++;
156442550becSJunchao Zhang     k = p;
156542550becSJunchao Zhang   }
156642550becSJunchao Zhang 
156742550becSJunchao Zhang   /* Build 1st SF to know offsets on remote to send data */
156842550becSJunchao Zhang   PetscSF     sf1;
156942550becSJunchao Zhang   PetscInt    nroots = 1,nroots2 = 0;
157042550becSJunchao Zhang   PetscInt    nleaves = nsend,nleaves2 = 0;
157142550becSJunchao Zhang   PetscInt    *offsets;
157242550becSJunchao Zhang   PetscSFNode *iremote;
157342550becSJunchao Zhang 
157442550becSJunchao Zhang   ierr = PetscSFCreate(comm,&sf1);CHKERRQ(ierr);
157542550becSJunchao Zhang   ierr = PetscMalloc1(nsend,&iremote);CHKERRQ(ierr);
157642550becSJunchao Zhang   ierr = PetscMalloc1(nsend,&offsets);CHKERRQ(ierr);
157742550becSJunchao Zhang   for (k=0; k<nsend; k++) {
157842550becSJunchao Zhang     iremote[k].rank  = sendto[k];
157942550becSJunchao Zhang     iremote[k].index = 0;
158042550becSJunchao Zhang     nleaves2        += nentries[k];
158142550becSJunchao Zhang   }
158242550becSJunchao Zhang   ierr = PetscSFSetGraph(sf1,nroots,nleaves,NULL,PETSC_OWN_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
158342550becSJunchao Zhang   ierr = PetscSFFetchAndOpWithMemTypeBegin(sf1,MPIU_INT,PETSC_MEMTYPE_HOST,&nroots2/*rootdata*/,PETSC_MEMTYPE_HOST,nentries/*leafdata*/,PETSC_MEMTYPE_HOST,offsets/*leafupdate*/,MPI_SUM);CHKERRQ(ierr);
158442550becSJunchao Zhang   ierr = PetscSFFetchAndOpEnd(sf1,MPIU_INT,&nroots2,nentries,offsets,MPI_SUM);CHKERRQ(ierr);
158542550becSJunchao Zhang   ierr = PetscSFDestroy(&sf1);CHKERRQ(ierr);
158642550becSJunchao Zhang 
158742550becSJunchao Zhang   /* Build 2nd SF to send remote COOs to their owner */
158842550becSJunchao Zhang   PetscSF sf2;
158942550becSJunchao Zhang   nroots  = nroots2;
159042550becSJunchao Zhang   nleaves = nleaves2;
159142550becSJunchao Zhang   ierr    = PetscSFCreate(comm,&sf2);CHKERRQ(ierr);
159242550becSJunchao Zhang   ierr    = PetscSFSetFromOptions(sf2);CHKERRQ(ierr);
159342550becSJunchao Zhang   ierr    = PetscMalloc1(nleaves,&iremote);CHKERRQ(ierr);
159442550becSJunchao Zhang   p       = 0;
159542550becSJunchao Zhang   for (k=0; k<nsend; k++) {
159642550becSJunchao Zhang     for (q=0; q<nentries[k]; q++,p++) {
159742550becSJunchao Zhang       iremote[p].rank  = sendto[k];
159842550becSJunchao Zhang       iremote[p].index = offsets[k] + q;
159942550becSJunchao Zhang     }
160042550becSJunchao Zhang   }
160142550becSJunchao Zhang   ierr = PetscSFSetGraph(sf2,nroots,nleaves,NULL,PETSC_USE_POINTER,iremote,PETSC_OWN_POINTER);CHKERRQ(ierr);
160242550becSJunchao Zhang 
160342550becSJunchao Zhang   /* sf2 only sends contiguous leafdata to contiguous rootdata. We record the permuation which will be used to fill leafdata */
160442550becSJunchao Zhang   ierr = PetscArraycpy(Cperm1_h.data(),perm1+rem,n1-rem);CHKERRQ(ierr);
160542550becSJunchao Zhang 
160642550becSJunchao Zhang   /* Send the remote COOs to their owner */
160742550becSJunchao Zhang   PetscInt n2 = nroots,*i2,*j2,*perm2; /* Buffers for received COOs from other ranks, along with a permutation array */
160842550becSJunchao Zhang   ierr = PetscMalloc3(n2,&i2,n2,&j2,n2,&perm2);CHKERRQ(ierr);
160942550becSJunchao Zhang   ierr = PetscSFReduceWithMemTypeBegin(sf2,MPIU_INT,PETSC_MEMTYPE_HOST,i1+rem,PETSC_MEMTYPE_HOST,i2,MPI_REPLACE);CHKERRQ(ierr);
161042550becSJunchao Zhang   ierr = PetscSFReduceEnd(sf2,MPIU_INT,i1+rem,i2,MPI_REPLACE);CHKERRQ(ierr);
161142550becSJunchao Zhang   ierr = PetscSFReduceWithMemTypeBegin(sf2,MPIU_INT,PETSC_MEMTYPE_HOST,j1+rem,PETSC_MEMTYPE_HOST,j2,MPI_REPLACE);CHKERRQ(ierr);
161242550becSJunchao Zhang   ierr = PetscSFReduceEnd(sf2,MPIU_INT,j1+rem,j2,MPI_REPLACE);CHKERRQ(ierr);
161342550becSJunchao Zhang 
161442550becSJunchao Zhang   ierr = PetscFree(offsets);CHKERRQ(ierr);
161542550becSJunchao Zhang   ierr = PetscFree2(sendto,nentries);CHKERRQ(ierr);
161642550becSJunchao Zhang 
161742550becSJunchao Zhang   /* ---------------------------------------------------------------*/
161842550becSJunchao Zhang   /* Sort received COOs by row along with the permutation array     */
161942550becSJunchao Zhang   /* ---------------------------------------------------------------*/
162042550becSJunchao Zhang   for (k=0; k<n2; k++) perm2[k] = k;
162142550becSJunchao Zhang   ierr = PetscSortIntWithArrayPair(n2,i2,j2,perm2);CHKERRQ(ierr);
162242550becSJunchao Zhang 
162342550becSJunchao Zhang   /* ---------------------------------------------------------------*/
162442550becSJunchao Zhang   /* Split received COOs into diag/offdiag portions                 */
162542550becSJunchao Zhang   /* ---------------------------------------------------------------*/
162642550becSJunchao Zhang   PetscInt                  *rowBegin2,*rowMid2,*rowEnd2;
162742550becSJunchao Zhang   MatRowMapKokkosViewHost   Ajmap2_h,Aperm2_h,Bjmap2_h,Bperm2_h;
162842550becSJunchao Zhang 
162942550becSJunchao Zhang   ierr = PetscCalloc3(m,&rowBegin2,m,&rowMid2,m,&rowEnd2);CHKERRQ(ierr);
163042550becSJunchao Zhang   ierr = MatSplitEntries_Internal(mat,n2,i2,j2,perm2,rowBegin2,rowMid2,rowEnd2,Aperm2_h,Ajmap2_h,Bperm2_h,Bjmap2_h);CHKERRQ(ierr);
163142550becSJunchao Zhang 
163242550becSJunchao Zhang   /* --------------------------------------------------------------------------*/
163342550becSJunchao Zhang   /* Merge local COOs with received COOs: diag with diag, offdiag with offdiag */
163442550becSJunchao Zhang   /* --------------------------------------------------------------------------*/
163542550becSJunchao Zhang   PetscInt Annz1,Annz2,Bnnz1,Bnnz2;
163642550becSJunchao Zhang   PetscInt *Ai,*Aj,*Bi,*Bj;
163742550becSJunchao Zhang 
163842550becSJunchao Zhang   Annz1 = Ajmap1_h.extent(0)-1; /* Number of unique local nonzeros in the diagonal block */
163942550becSJunchao Zhang   Annz2 = Ajmap2_h.extent(0)-1; /* Number of unique received nonzeros in the diagonal block */
164042550becSJunchao Zhang   Bnnz1 = Bjmap1_h.extent(0)-1; /* Similar, but for the off-diagonal block */
164142550becSJunchao Zhang   Bnnz2 = Bjmap2_h.extent(0)-1;
164242550becSJunchao Zhang   ierr  = PetscMalloc1(m+1,&Ai);CHKERRQ(ierr);
164342550becSJunchao Zhang   ierr  = PetscMalloc1(m+1,&Bi);CHKERRQ(ierr);
164442550becSJunchao Zhang   ierr  = PetscMalloc1(Annz1+Annz2,&Aj);CHKERRQ(ierr); /* Since local and remote entries might have dups, we might allocate excess memory */
164542550becSJunchao Zhang   ierr  = PetscMalloc1(Bnnz1+Bnnz2,&Bj);CHKERRQ(ierr);
164642550becSJunchao Zhang 
164742550becSJunchao Zhang   MatRowMapKokkosViewHost Aimap1_h("Aimpa1",Annz1),Aimap2_h("Aimpa2",Annz2),Bimap1_h("Bimap1",Bnnz1),Bimap2_h("Bimap2",Bnnz2);
164842550becSJunchao Zhang   ierr = MatMergeEntries_Internal(mat,j1,j2,rowBegin1,rowMid1,rowBegin2,rowMid2,Ajmap1_h,Ajmap2_h,Aimap1_h,Aimap2_h,Ai,Aj);CHKERRQ(ierr);
164942550becSJunchao Zhang   ierr = MatMergeEntries_Internal(mat,j1,j2,rowMid1,  rowEnd1,rowMid2,  rowEnd2,Bjmap1_h,Bjmap2_h,Bimap1_h,Bimap2_h,Bi,Bj);CHKERRQ(ierr);
165042550becSJunchao Zhang   ierr = PetscFree3(rowBegin1,rowMid1,rowEnd1);CHKERRQ(ierr);
165142550becSJunchao Zhang   ierr = PetscFree3(rowBegin2,rowMid2,rowEnd2);CHKERRQ(ierr);
165242550becSJunchao Zhang   ierr = PetscFree3(i1,j1,perm1);CHKERRQ(ierr);
165342550becSJunchao Zhang   ierr = PetscFree3(i2,j2,perm2);CHKERRQ(ierr);
165442550becSJunchao Zhang 
165542550becSJunchao Zhang   /* Reallocate Aj, Bj once we know actual numbers of unique nonzeros in A and B */
165642550becSJunchao Zhang   PetscInt Annz = Ai[m];
165742550becSJunchao Zhang   PetscInt Bnnz = Bi[m];
165842550becSJunchao Zhang   if (Annz < Annz1 + Annz2) {
165942550becSJunchao Zhang     PetscInt *Aj_new;
166042550becSJunchao Zhang     ierr = PetscMalloc1(Annz,&Aj_new);CHKERRQ(ierr);
166142550becSJunchao Zhang     ierr = PetscArraycpy(Aj_new,Aj,Annz);CHKERRQ(ierr);
166242550becSJunchao Zhang     ierr = PetscFree(Aj);CHKERRQ(ierr);
166342550becSJunchao Zhang     Aj   = Aj_new;
166442550becSJunchao Zhang   }
166542550becSJunchao Zhang 
166642550becSJunchao Zhang   if (Bnnz < Bnnz1 + Bnnz2) {
166742550becSJunchao Zhang     PetscInt *Bj_new;
166842550becSJunchao Zhang     ierr = PetscMalloc1(Bnnz,&Bj_new);CHKERRQ(ierr);
166942550becSJunchao Zhang     ierr = PetscArraycpy(Bj_new,Bj,Bnnz);CHKERRQ(ierr);
167042550becSJunchao Zhang     ierr = PetscFree(Bj);CHKERRQ(ierr);
167142550becSJunchao Zhang     Bj   = Bj_new;
167242550becSJunchao Zhang   }
167342550becSJunchao Zhang 
167442550becSJunchao Zhang   /* --------------------------------------------------------------------------------*/
167542550becSJunchao Zhang   /* Create a MPIAIJKOKKOS newmat with CSRs of A and B, then replace mat with newmat */
167642550becSJunchao Zhang   /* --------------------------------------------------------------------------------*/
167742550becSJunchao Zhang   Mat           newmat;
167842550becSJunchao Zhang   PetscScalar   *Aa,*Ba;
167942550becSJunchao Zhang   Mat_MPIAIJ    *mpiaij;
168042550becSJunchao Zhang   Mat_SeqAIJ    *a,*b;
168142550becSJunchao Zhang 
168242550becSJunchao Zhang   ierr   = PetscMalloc1(Annz,&Aa);CHKERRQ(ierr);
168342550becSJunchao Zhang   ierr   = PetscMalloc1(Bnnz,&Ba);CHKERRQ(ierr);
168442550becSJunchao Zhang   /* make Aj[] local, i.e, based off the start column of the diagonal portion */
168542550becSJunchao Zhang   if (cstart) {for (k=0; k<Annz; k++) Aj[k] -= cstart;}
168642550becSJunchao Zhang   ierr   = MatCreateMPIAIJWithSplitArrays(comm,m,n,M,N,Ai,Aj,Aa,Bi,Bj,Ba,&newmat);CHKERRQ(ierr);
168742550becSJunchao Zhang   mpiaij = (Mat_MPIAIJ*)newmat->data;
168842550becSJunchao Zhang   a      = (Mat_SeqAIJ*)mpiaij->A->data;
168942550becSJunchao Zhang   b      = (Mat_SeqAIJ*)mpiaij->B->data;
169042550becSJunchao Zhang   a->singlemalloc = b->singlemalloc = PETSC_FALSE; /* Let newmat own Ai,Aj,Aa,Bi,Bj,Ba */
169142550becSJunchao Zhang   a->free_a       = b->free_a       = PETSC_TRUE;
169242550becSJunchao Zhang   a->free_ij      = b->free_ij      = PETSC_TRUE;
169342550becSJunchao Zhang   ierr   = MatConvert(newmat,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&newmat);CHKERRQ(ierr);
169442550becSJunchao Zhang   ierr   = MatHeaderMerge(mat,&newmat);CHKERRQ(ierr);
169542550becSJunchao Zhang   ierr   = MatZeroEntries(mat);CHKERRQ(ierr); /* Zero matrix on device */
169642550becSJunchao Zhang   mpiaij = (Mat_MPIAIJ*)mat->data;
169742550becSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(n1,sf2,nroots,nleaves,Annz1,Annz2,Bnnz1,Bnnz2,
169842550becSJunchao Zhang                                        Aimap1_h,Aimap2_h,Bimap1_h,Bimap2_h,
169942550becSJunchao Zhang                                        Ajmap1_h,Ajmap2_h,Bjmap1_h,Bjmap2_h,
170042550becSJunchao Zhang                                        Aperm1_h,Aperm2_h,Bperm1_h,Bperm2_h,Cperm1_h);
170142550becSJunchao Zhang   ierr = PetscSFDestroy(&sf2);CHKERRQ(ierr); /* ctor of Mat_MPIAIJKokkos already took a reference of sf3 */
170242550becSJunchao Zhang   PetscFunctionReturn(0);
170342550becSJunchao Zhang }
170442550becSJunchao Zhang 
170542550becSJunchao Zhang static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat,const PetscScalar v[],InsertMode imode)
170642550becSJunchao Zhang {
170742550becSJunchao Zhang   PetscErrorCode                 ierr;
170842550becSJunchao Zhang   Mat_MPIAIJ                     *mpiaij = (Mat_MPIAIJ*)mat->data;
170942550becSJunchao Zhang   Mat_MPIAIJKokkos               *mpikok = static_cast<Mat_MPIAIJKokkos*>(mpiaij->spptr);
171042550becSJunchao Zhang   Mat                            A = mpiaij->A,B = mpiaij->B;
171142550becSJunchao Zhang   PetscInt                       Annz1 = mpikok->Annz1,Annz2 = mpikok->Annz2,Bnnz1 = mpikok->Bnnz1,Bnnz2 = mpikok->Bnnz2;
171242550becSJunchao Zhang   MatScalarKokkosView            Aa,Ba;
171342550becSJunchao Zhang   ConstMatScalarKokkosView       v1;
171442550becSJunchao Zhang   MatScalarKokkosView&           vsend = mpikok->sendbuf_d;
171542550becSJunchao Zhang   const MatScalarKokkosView&     v2 = mpikok->recvbuf_d;
171642550becSJunchao Zhang   const MatRowMapKokkosView&     Ajmap1 = mpikok->Ajmap1_d,Ajmap2 = mpikok->Ajmap2_d,Aimap1 = mpikok->Aimap1_d,Aimap2 = mpikok->Aimap2_d;
171742550becSJunchao Zhang   const MatRowMapKokkosView&     Bjmap1 = mpikok->Bjmap1_d,Bjmap2 = mpikok->Bjmap2_d,Bimap1 = mpikok->Bimap1_d,Bimap2 = mpikok->Bimap2_d;
171842550becSJunchao Zhang   const MatRowMapKokkosView&     Aperm1 = mpikok->Aperm1_d,Aperm2 = mpikok->Aperm2_d,Bperm1 = mpikok->Bperm1_d,Bperm2 = mpikok->Bperm2_d;
171942550becSJunchao Zhang   const MatRowMapKokkosView&     Cperm1 = mpikok->Cperm1_d;
172042550becSJunchao Zhang   PetscMemType                   memtype;
172142550becSJunchao Zhang 
172242550becSJunchao Zhang   PetscFunctionBegin;
172342550becSJunchao Zhang   if (!v) { /* NULL v means an all zero array */
172442550becSJunchao Zhang     ierr = MatZeroEntries(mat);CHKERRQ(ierr);
172542550becSJunchao Zhang     PetscFunctionReturn(0);
172642550becSJunchao Zhang   }
172742550becSJunchao Zhang 
172842550becSJunchao Zhang   ierr = PetscGetMemType(v,&memtype);CHKERRQ(ierr);
172942550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */
173042550becSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),ConstMatScalarKokkosViewHost(v,mpikok->coo_n));
173142550becSJunchao Zhang   } else {
173242550becSJunchao Zhang     v1 = ConstMatScalarKokkosView(v,mpikok->coo_n); /* Directly use v[]'s memory */
173342550becSJunchao Zhang   }
173442550becSJunchao Zhang 
173542550becSJunchao Zhang   ierr = MatSeqAIJGetKokkosView(A,&Aa);CHKERRQ(ierr); /* Might read and write matrix values */
173642550becSJunchao Zhang   ierr = MatSeqAIJGetKokkosView(B,&Ba);CHKERRQ(ierr);
173742550becSJunchao Zhang   if (imode == INSERT_VALUES) {
173842550becSJunchao Zhang     Kokkos::deep_copy(Aa,0.0); /* Zero matrix values since INSERT_VALUES still requires summing replicated values in v[] */
173942550becSJunchao Zhang     Kokkos::deep_copy(Ba,0.0);
174042550becSJunchao Zhang   }
174142550becSJunchao Zhang 
174242550becSJunchao Zhang   /* Pack entries to be sent to remote */
174342550becSJunchao Zhang   Kokkos::parallel_for(vsend.extent(0),KOKKOS_LAMBDA(const PetscInt i) {vsend(i) = v1(Cperm1(i));});
174442550becSJunchao Zhang 
174542550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
174642550becSJunchao Zhang   ierr = PetscSFReduceWithMemTypeBegin(mpikok->coo_sf,MPIU_SCALAR,PETSC_MEMTYPE_KOKKOS,vsend.data(),PETSC_MEMTYPE_KOKKOS,v2.data(),MPI_REPLACE);CHKERRQ(ierr);
174742550becSJunchao Zhang   /* Add local entries to A and B */
174842550becSJunchao Zhang   Kokkos::parallel_for(Annz1,KOKKOS_LAMBDA(const PetscInt i) {for (PetscInt k=Ajmap1(i); k<Ajmap1(i+1); k++) Aa(Aimap1(i)) += v1(Aperm1(k));});
174942550becSJunchao Zhang   Kokkos::parallel_for(Bnnz1,KOKKOS_LAMBDA(const PetscInt i) {for (PetscInt k=Bjmap1(i); k<Bjmap1(i+1); k++) Ba(Bimap1(i)) += v1(Bperm1(k));});
175042550becSJunchao Zhang   ierr = PetscSFReduceEnd(mpikok->coo_sf,MPIU_SCALAR,vsend.data(),v2.data(),MPI_REPLACE);CHKERRQ(ierr);
175142550becSJunchao Zhang 
175242550becSJunchao Zhang   /* Add received remote entries to A and B */
175342550becSJunchao Zhang   Kokkos::parallel_for(Annz2,KOKKOS_LAMBDA(const PetscInt i) {for (PetscInt k=Ajmap2(i); k<Ajmap2(i+1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));});
175442550becSJunchao Zhang   Kokkos::parallel_for(Bnnz2,KOKKOS_LAMBDA(const PetscInt i) {for (PetscInt k=Bjmap2(i); k<Bjmap2(i+1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));});
175542550becSJunchao Zhang 
175642550becSJunchao Zhang   ierr = MatSeqAIJRestoreKokkosView(A,&Aa);CHKERRQ(ierr);
175742550becSJunchao Zhang   ierr = MatSeqAIJRestoreKokkosView(B,&Ba);CHKERRQ(ierr);
175842550becSJunchao Zhang 
175942550becSJunchao Zhang   ierr = MatAssemblyBegin(A,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
176042550becSJunchao Zhang   ierr = MatAssemblyEnd(A,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
176142550becSJunchao Zhang   PetscFunctionReturn(0);
176242550becSJunchao Zhang }
176342550becSJunchao Zhang 
1764076ba34aSJunchao Zhang PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1765076ba34aSJunchao Zhang {
1766076ba34aSJunchao Zhang   PetscErrorCode     ierr;
176742550becSJunchao Zhang   Mat_MPIAIJ         *mpiaij = (Mat_MPIAIJ*)A->data;
1768076ba34aSJunchao Zhang 
1769076ba34aSJunchao Zhang   PetscFunctionBegin;
1770076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",NULL);CHKERRQ(ierr);
1771076ba34aSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatMPIAIJGetLocalMatMerge_C",NULL);CHKERRQ(ierr);
177242550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatSetPreallocationCOO_C",   NULL);CHKERRQ(ierr);
177342550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)A,"MatSetValuesCOO_C",          NULL);CHKERRQ(ierr);
177442550becSJunchao Zhang   delete (Mat_MPIAIJKokkos*)mpiaij->spptr;
1775076ba34aSJunchao Zhang   ierr = MatDestroy_MPIAIJ(A);CHKERRQ(ierr);
1776076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1777076ba34aSJunchao Zhang }
1778076ba34aSJunchao Zhang 
17798c3ff71bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat* newmat)
17808c3ff71bSJunchao Zhang {
17818c3ff71bSJunchao Zhang   PetscErrorCode     ierr;
17828c3ff71bSJunchao Zhang   Mat                B;
1783076ba34aSJunchao Zhang   Mat_MPIAIJ         *a;
17848c3ff71bSJunchao Zhang 
17858c3ff71bSJunchao Zhang   PetscFunctionBegin;
17868c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17878c3ff71bSJunchao Zhang     ierr = MatDuplicate(A,MAT_COPY_VALUES,newmat);CHKERRQ(ierr);
17888c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17898c3ff71bSJunchao Zhang     ierr = MatCopy(A,*newmat,SAME_NONZERO_PATTERN);CHKERRQ(ierr);
17908c3ff71bSJunchao Zhang   }
17918c3ff71bSJunchao Zhang   B = *newmat;
17928c3ff71bSJunchao Zhang 
17936f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17948c3ff71bSJunchao Zhang   ierr = PetscFree(B->defaultvectype);CHKERRQ(ierr);
17958c3ff71bSJunchao Zhang   ierr = PetscStrallocpy(VECKOKKOS,&B->defaultvectype);CHKERRQ(ierr);
17963d0639e7SStefano Zampini   ierr = PetscObjectChangeTypeName((PetscObject)B,MATMPIAIJKOKKOS);CHKERRQ(ierr);
17978c3ff71bSJunchao Zhang 
1798076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ*>(A->data);
1799076ba34aSJunchao Zhang   if (a->A) {ierr = MatSetType(a->A,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1800076ba34aSJunchao Zhang   if (a->B) {ierr = MatSetType(a->B,MATSEQAIJKOKKOS);CHKERRQ(ierr);}
1801076ba34aSJunchao Zhang   if (a->lvec) {ierr = VecSetType(a->lvec,VECSEQKOKKOS);CHKERRQ(ierr);}
1802076ba34aSJunchao Zhang 
18038c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
18048c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
18058c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
18068c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1807076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1808076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
18098c3ff71bSJunchao Zhang 
18103d0639e7SStefano Zampini   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJSetPreallocation_C",MatMPIAIJSetPreallocation_MPIAIJKokkos);CHKERRQ(ierr);
181142550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)B,"MatMPIAIJGetLocalMatMerge_C",MatMPIAIJGetLocalMatMerge_MPIAIJKokkos);CHKERRQ(ierr);
181242550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)B,"MatSetPreallocationCOO_C",   MatSetPreallocationCOO_MPIAIJKokkos);CHKERRQ(ierr);
181342550becSJunchao Zhang   ierr = PetscObjectComposeFunction((PetscObject)B,"MatSetValuesCOO_C",          MatSetValuesCOO_MPIAIJKokkos);CHKERRQ(ierr);
18148c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
18158c3ff71bSJunchao Zhang }
18168c3ff71bSJunchao Zhang 
18178c3ff71bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
18188c3ff71bSJunchao Zhang {
18198c3ff71bSJunchao Zhang   PetscErrorCode ierr;
18208c3ff71bSJunchao Zhang 
18218c3ff71bSJunchao Zhang   PetscFunctionBegin;
18228c3ff71bSJunchao Zhang   ierr = PetscKokkosInitializeCheck();CHKERRQ(ierr);
18238c3ff71bSJunchao Zhang   ierr = MatCreate_MPIAIJ(A);CHKERRQ(ierr);
18248c3ff71bSJunchao Zhang   ierr = MatConvert_MPIAIJ_MPIAIJKokkos(A,MATMPIAIJKOKKOS,MAT_INPLACE_MATRIX,&A);CHKERRQ(ierr);
18258c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
18268c3ff71bSJunchao Zhang }
18278c3ff71bSJunchao Zhang 
18288c3ff71bSJunchao Zhang /*@C
18298c3ff71bSJunchao Zhang    MatCreateAIJKokkos - Creates a sparse matrix in AIJ (compressed row) format
18308c3ff71bSJunchao Zhang    (the default parallel PETSc format).  This matrix will ultimately pushed down
18318c3ff71bSJunchao Zhang    to Kokkos for calculations. For good matrix
18328c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
18338c3ff71bSJunchao Zhang    the parameter nz (or the array nnz).  By setting these parameters accurately,
18348c3ff71bSJunchao Zhang    performance during matrix assembly can be increased by more than a factor of 50.
18358c3ff71bSJunchao Zhang 
18368c3ff71bSJunchao Zhang    Collective
18378c3ff71bSJunchao Zhang 
18388c3ff71bSJunchao Zhang    Input Parameters:
18398c3ff71bSJunchao Zhang +  comm - MPI communicator, set to PETSC_COMM_SELF
18408c3ff71bSJunchao Zhang .  m - number of rows
18418c3ff71bSJunchao Zhang .  n - number of columns
18428c3ff71bSJunchao Zhang .  nz - number of nonzeros per row (same for all rows)
18438c3ff71bSJunchao Zhang -  nnz - array containing the number of nonzeros in the various rows
18448c3ff71bSJunchao Zhang          (possibly different for each row) or NULL
18458c3ff71bSJunchao Zhang 
18468c3ff71bSJunchao Zhang    Output Parameter:
18478c3ff71bSJunchao Zhang .  A - the matrix
18488c3ff71bSJunchao Zhang 
18498c3ff71bSJunchao Zhang    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions(),
18508c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
18518c3ff71bSJunchao Zhang    [MatXXXXSetPreallocation() is, for example, MatSeqAIJSetPreallocation]
18528c3ff71bSJunchao Zhang 
18538c3ff71bSJunchao Zhang    Notes:
18548c3ff71bSJunchao Zhang    If nnz is given then nz is ignored
18558c3ff71bSJunchao Zhang 
18568c3ff71bSJunchao Zhang    The AIJ format (also called the Yale sparse matrix format or
18578c3ff71bSJunchao Zhang    compressed row storage), is fully compatible with standard Fortran 77
18588c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
18598c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
18608c3ff71bSJunchao Zhang 
18618c3ff71bSJunchao Zhang    Specify the preallocated storage with either nz or nnz (not both).
18628c3ff71bSJunchao Zhang    Set nz=PETSC_DEFAULT and nnz=NULL for PETSc to control dynamic memory
18638c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
18648c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
18658c3ff71bSJunchao Zhang 
18668c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
18678c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
18688c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
18698c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
18708c3ff71bSJunchao Zhang 
18718c3ff71bSJunchao Zhang    Level: intermediate
18728c3ff71bSJunchao Zhang 
18738c3ff71bSJunchao Zhang .seealso: MatCreate(), MatCreateAIJ(), MatSetValues(), MatSeqAIJSetColumnIndices(), MatCreateSeqAIJWithArrays(), MatCreateAIJ(), MATMPIAIJKOKKOS, MATAIJKokkos
18748c3ff71bSJunchao Zhang @*/
18758c3ff71bSJunchao Zhang PetscErrorCode  MatCreateAIJKokkos(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscInt d_nz,const PetscInt d_nnz[],PetscInt o_nz,const PetscInt o_nnz[],Mat *A)
18768c3ff71bSJunchao Zhang {
18778c3ff71bSJunchao Zhang   PetscErrorCode ierr;
18788c3ff71bSJunchao Zhang   PetscMPIInt    size;
18798c3ff71bSJunchao Zhang 
18808c3ff71bSJunchao Zhang   PetscFunctionBegin;
18818c3ff71bSJunchao Zhang   ierr = MatCreate(comm,A);CHKERRQ(ierr);
18828c3ff71bSJunchao Zhang   ierr = MatSetSizes(*A,m,n,M,N);CHKERRQ(ierr);
1883ffc4695bSBarry Smith   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
18848c3ff71bSJunchao Zhang   if (size > 1) {
18858c3ff71bSJunchao Zhang     ierr = MatSetType(*A,MATMPIAIJKOKKOS);CHKERRQ(ierr);
18868c3ff71bSJunchao Zhang     ierr = MatMPIAIJSetPreallocation(*A,d_nz,d_nnz,o_nz,o_nnz);CHKERRQ(ierr);
18878c3ff71bSJunchao Zhang   } else {
18888c3ff71bSJunchao Zhang     ierr = MatSetType(*A,MATSEQAIJKOKKOS);CHKERRQ(ierr);
18898c3ff71bSJunchao Zhang     ierr = MatSeqAIJSetPreallocation(*A,d_nz,d_nnz);CHKERRQ(ierr);
18908c3ff71bSJunchao Zhang   }
18918c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
18928c3ff71bSJunchao Zhang }
18938c3ff71bSJunchao Zhang 
1894a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1895042217e8SBarry Smith PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1896a587d139SMark {
1897a587d139SMark   PetscMPIInt                size,rank;
1898a587d139SMark   MPI_Comm                   comm;
1899a587d139SMark   PetscErrorCode             ierr;
1900042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat=NULL;
1901a587d139SMark 
1902a587d139SMark   PetscFunctionBegin;
1903a587d139SMark   ierr = PetscObjectGetComm((PetscObject)A,&comm);CHKERRQ(ierr);
190455b25c41SPierre Jolivet   ierr = MPI_Comm_size(comm,&size);CHKERRMPI(ierr);
190555b25c41SPierre Jolivet   ierr = MPI_Comm_rank(comm,&rank);CHKERRMPI(ierr);
1906a587d139SMark   if (size == 1) {
1907a587d139SMark     ierr   = MatSeqAIJKokkosGetDeviceMat(A,&d_mat);CHKERRQ(ierr);
1908fc76dfabSMark Adams     ierr   = MatSeqAIJKokkosModifyDevice(A);CHKERRQ(ierr); /* Since we are going to modify matrix values on device */
1909a587d139SMark   } else {
1910a587d139SMark     Mat_MPIAIJ  *aij = (Mat_MPIAIJ*)A->data;
1911a587d139SMark     ierr   = MatSeqAIJKokkosGetDeviceMat(aij->A,&d_mat);CHKERRQ(ierr);
1912fc76dfabSMark Adams     ierr   = MatSeqAIJKokkosModifyDevice(aij->A);CHKERRQ(ierr);
1913fc76dfabSMark Adams     ierr   = MatSeqAIJKokkosModifyDevice(aij->B);CHKERRQ(ierr);
19146c148233SMark Adams     if (!A->nooffprocentries && !aij->donotstash) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1915a587d139SMark   }
1916a587d139SMark   // act like MatSetValues because not called on host
1917a587d139SMark   if (A->assembled) {
1918a587d139SMark     if (A->was_assembled) {
1919a587d139SMark       ierr = PetscInfo(A,"Assemble more than once already\n");CHKERRQ(ierr);
1920a587d139SMark     }
1921a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1922a587d139SMark   } else {
1923c0aa6a63SJacob Faibussowitsch     ierr = PetscInfo1(A,"Warning !assemble ??? assembled=%" PetscInt_FMT "\n",A->assembled);CHKERRQ(ierr);
1924a587d139SMark   }
1925a587d139SMark   if (!d_mat) {
1926042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1927a587d139SMark     Mat_SeqAIJKokkos      *aijkokA;
1928a587d139SMark     Mat_SeqAIJ            *jaca;
1929a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1930a587d139SMark     Mat                   Amat;
1931042217e8SBarry Smith     PetscInt              *colmap;
1932042217e8SBarry Smith 
1933042217e8SBarry Smith     /* create and copy h_mat */
193449b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
1935a587d139SMark     ierr = PetscInfo(A,"Create device matrix in Kokkos\n");CHKERRQ(ierr);
1936a587d139SMark     if (size == 1) {
1937a587d139SMark       Amat = A;
1938a587d139SMark       jaca = (Mat_SeqAIJ*)A->data;
1939a587d139SMark       h_mat.rstart = 0; h_mat.rend = A->rmap->n;
1940a587d139SMark       h_mat.cstart = 0; h_mat.cend = A->cmap->n;
1941a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1942a587d139SMark       h_mat.offdiag.a = NULL;
1943a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
1944a587d139SMark     } else {
1945a587d139SMark       Mat_MPIAIJ       *aij = (Mat_MPIAIJ*)A->data;
1946a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ*)aij->B->data;
1947a587d139SMark       PetscInt         ii;
1948a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1949042217e8SBarry Smith 
1950a587d139SMark       Amat = aij->A;
1951a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos*>(aij->A->spptr);
1952a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos*>(aij->B->spptr);
1953a587d139SMark       jaca = (Mat_SeqAIJ*)aij->A->data;
1954b3c64f9dSJunchao Zhang       if (aij->B->cmap->n && !aij->garray) SETERRQ(comm,PETSC_ERR_PLIB,"MPIAIJ Matrix was assembled but is missing garray");
1955a587d139SMark       if (aij->B->rmap->n != aij->A->rmap->n) SETERRQ(comm,PETSC_ERR_SUP,"Only support aij->B->rmap->n == aij->A->rmap->n");
1956a587d139SMark       aij->donotstash = PETSC_TRUE;
1957a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1958a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1959042217e8SBarry Smith       ierr = PetscCalloc1(A->cmap->N,&colmap);CHKERRQ(ierr);
1960042217e8SBarry Smith       ierr = PetscLogObjectMemory((PetscObject)A,(A->cmap->N)*sizeof(PetscInt));CHKERRQ(ierr);
1961042217e8SBarry Smith       for (ii=0; ii<aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii+1;
1962a587d139SMark       // allocate B copy data
1963a587d139SMark       h_mat.rstart = A->rmap->rstart; h_mat.rend = A->rmap->rend;
1964a587d139SMark       h_mat.cstart = A->cmap->rstart; h_mat.cend = A->cmap->rend;
1965a587d139SMark       nnz = jacb->i[n];
1966a587d139SMark       if (jacb->compressedrow.use) {
1967a587d139SMark         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_i_k (jacb->i,n+1);
1968300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_i_k));
1969300d22a6SJunchao Zhang         Kokkos::deep_copy (aijkokB->i_uncompressed_d, h_i_k);
1970300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1971a587d139SMark       } else {
197299551766SMark Adams          h_mat.offdiag.i = aijkokB->i_device_data();
1973a587d139SMark       }
197499551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1975076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1976a587d139SMark       {
1977042217e8SBarry Smith         Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_colmap_k (colmap,A->cmap->N);
1978300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_colmap_k));
1979300d22a6SJunchao Zhang         Kokkos::deep_copy (aijkokB->colmap_d, h_colmap_k);
1980300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
1981042217e8SBarry Smith         ierr = PetscFree(colmap);CHKERRQ(ierr);
1982a587d139SMark       }
1983a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1984a587d139SMark       h_mat.offdiag.n = n;
1985a587d139SMark     }
1986a587d139SMark     // allocate A copy data
1987a587d139SMark     nnz = jaca->i[n];
1988a587d139SMark     h_mat.diag.n = n;
1989a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
199055b25c41SPierre Jolivet     ierr = MPI_Comm_rank(comm,&h_mat.rank);CHKERRMPI(ierr);
1991042217e8SBarry Smith     if (jaca->compressedrow.use) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"A does not suppport compressed row (todo)");
1992042217e8SBarry Smith     else {
199399551766SMark Adams       h_mat.diag.i = aijkokA->i_device_data();
1994a587d139SMark     }
199599551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1996076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1997a587d139SMark     // copy pointers and metdata to device
1998a587d139SMark     ierr = MatSeqAIJKokkosSetDeviceMat(Amat,&h_mat);CHKERRQ(ierr);
1999a587d139SMark     ierr = MatSeqAIJKokkosGetDeviceMat(Amat,&d_mat);CHKERRQ(ierr);
2000c0aa6a63SJacob Faibussowitsch     ierr = PetscInfo2(A,"Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n",h_mat.diag.n, nnz);CHKERRQ(ierr);
2001a587d139SMark   }
2002a587d139SMark   *B = d_mat; // return it, set it in Mat, and set it up
2003a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
2004a587d139SMark   PetscFunctionReturn(0);
2005a587d139SMark }
2006076ba34aSJunchao Zhang 
2007076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
2008076ba34aSJunchao Zhang {
2009076ba34aSJunchao Zhang   Mat_SeqAIJKokkos  *aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
2010076ba34aSJunchao Zhang 
2011076ba34aSJunchao Zhang   PetscFunctionBegin;
2012076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
2013076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
2014076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
2015076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
2016076ba34aSJunchao Zhang   PetscFunctionReturn(0);
2017076ba34aSJunchao Zhang }
2018076ba34aSJunchao Zhang 
2019076ba34aSJunchao Zhang PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
2020076ba34aSJunchao Zhang {
2021076ba34aSJunchao Zhang   PetscErrorCode    ierr;
2022076ba34aSJunchao Zhang   PetscMPIInt       size;
2023076ba34aSJunchao Zhang   Mat               Ad,Ao;
2024076ba34aSJunchao Zhang   const char        *amask,*bmask;
2025076ba34aSJunchao Zhang 
2026076ba34aSJunchao Zhang   PetscFunctionBegin;
2027076ba34aSJunchao Zhang   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRMPI(ierr);
2028076ba34aSJunchao Zhang 
2029076ba34aSJunchao Zhang   if (size == 1) {
2030076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(A,&amask);CHKERRQ(ierr);
2031076ba34aSJunchao Zhang     ierr = PetscPrintf(PETSC_COMM_SELF,"%s\n",amask);CHKERRQ(ierr);
2032076ba34aSJunchao Zhang   } else {
2033076ba34aSJunchao Zhang     Ad  = ((Mat_MPIAIJ*)A->data)->A;
2034076ba34aSJunchao Zhang     Ao  = ((Mat_MPIAIJ*)A->data)->B;
2035076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(Ad,&amask);CHKERRQ(ierr);
2036076ba34aSJunchao Zhang     ierr = MatSeqAIJKokkosGetOffloadMask(Ao,&bmask);CHKERRQ(ierr);
2037076ba34aSJunchao Zhang     ierr = PetscPrintf(PETSC_COMM_SELF,"Diag : Off-diag = %s : %s\n",amask,bmask);CHKERRQ(ierr);
2038076ba34aSJunchao Zhang   }
2039076ba34aSJunchao Zhang   PetscFunctionReturn(0);
2040076ba34aSJunchao Zhang }
2041