xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 0e3ece09e3140f5ddabf8db2b6f5fd48b6ec6274)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
3076ba34aSJunchao Zhang #include <petscsf.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
642550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
8*0e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
911d22bbfSJunchao Zhang 
10d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11d71ae5a4SJacob Faibussowitsch {
125519a089SJose E. Roman   Mat_SeqAIJKokkos *aijkok;
1330203840SJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
148c3ff71bSJunchao Zhang 
158c3ff71bSJunchao Zhang   PetscFunctionBegin;
169566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1730203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1830203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1930203840SJunchao Zhang    */
2030203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2130203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2230203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2330203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2430203840SJunchao Zhang   }
255519a089SJose E. Roman   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
26a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
27a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
28a587d139SMark   }
29a587d139SMark 
303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
318c3ff71bSJunchao Zhang }
328c3ff71bSJunchao Zhang 
33d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
34d71ae5a4SJacob Faibussowitsch {
358c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
368c3ff71bSJunchao Zhang 
378c3ff71bSJunchao Zhang   PetscFunctionBegin;
389566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
399566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
406a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
418c3ff71bSJunchao Zhang   if (d_nnz) {
426a29ce69SStefano Zampini     PetscInt i;
43ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
448c3ff71bSJunchao Zhang   }
458c3ff71bSJunchao Zhang   if (o_nnz) {
466a29ce69SStefano Zampini     PetscInt i;
47ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
488c3ff71bSJunchao Zhang   }
496a29ce69SStefano Zampini #endif
506a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
51eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
526a29ce69SStefano Zampini #else
539566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
546a29ce69SStefano Zampini #endif
559566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
569566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
579566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
586a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
599566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
606a29ce69SStefano Zampini 
616a29ce69SStefano Zampini   if (!mpiaij->A) {
629566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
639566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
646a29ce69SStefano Zampini   }
656a29ce69SStefano Zampini   if (!mpiaij->B) {
666a29ce69SStefano Zampini     PetscMPIInt size;
679566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
689566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
699566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
708c3ff71bSJunchao Zhang   }
719566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
729566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
749566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
758c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
778c3ff71bSJunchao Zhang }
788c3ff71bSJunchao Zhang 
79d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
80d71ae5a4SJacob Faibussowitsch {
818c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
828c3ff71bSJunchao Zhang   PetscInt    nt;
838c3ff71bSJunchao Zhang 
848c3ff71bSJunchao Zhang   PetscFunctionBegin;
859566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8608401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
879566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
889566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
899566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
909566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
928c3ff71bSJunchao Zhang }
938c3ff71bSJunchao Zhang 
94d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
95d71ae5a4SJacob Faibussowitsch {
968c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
978c3ff71bSJunchao Zhang   PetscInt    nt;
988c3ff71bSJunchao Zhang 
998c3ff71bSJunchao Zhang   PetscFunctionBegin;
1009566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
10108401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
1029566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1039566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
1049566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1059566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
1063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1078c3ff71bSJunchao Zhang }
1088c3ff71bSJunchao Zhang 
109d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
110d71ae5a4SJacob Faibussowitsch {
1118c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1128c3ff71bSJunchao Zhang   PetscInt    nt;
1138c3ff71bSJunchao Zhang 
1148c3ff71bSJunchao Zhang   PetscFunctionBegin;
1159566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
11608401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1179566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1189566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1199566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1209566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1228c3ff71bSJunchao Zhang }
1238c3ff71bSJunchao Zhang 
124076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
125076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
126076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
127076ba34aSJunchao Zhang */
128d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
129d71ae5a4SJacob Faibussowitsch {
130076ba34aSJunchao Zhang   Mat             Ad, Ao;
131076ba34aSJunchao Zhang   const PetscInt *cmap;
132076ba34aSJunchao Zhang 
133076ba34aSJunchao Zhang   PetscFunctionBegin;
1349566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1359566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
136076ba34aSJunchao Zhang   if (glob) {
137076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1389566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1399566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1409566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1419566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
142076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
143076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1449566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
145076ba34aSJunchao Zhang   }
1463ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
147076ba34aSJunchao Zhang }
148076ba34aSJunchao Zhang 
149*0e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
150076ba34aSJunchao Zhang struct MatMatStruct {
151*0e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
152*0e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
153*0e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
154*0e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
155*0e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
156*0e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
157*0e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
158*0e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
159*0e3ece09SJunchao Zhang 
160*0e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
161*0e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
162*0e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
163*0e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
164*0e3ece09SJunchao Zhang 
165*0e3ece09SJunchao Zhang   PetscInt E_TeamSize; // kernel launching parameters in merging E or spliting F
166*0e3ece09SJunchao Zhang   PetscInt E_VectorLength;
167*0e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
168*0e3ece09SJunchao Zhang   PetscInt F_TeamSize;
169*0e3ece09SJunchao Zhang   PetscInt F_VectorLength;
170*0e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
171076ba34aSJunchao Zhang 
172d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
173d71ae5a4SJacob Faibussowitsch   {
1743ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1753ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1763ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
177076ba34aSJunchao Zhang   }
178076ba34aSJunchao Zhang };
179076ba34aSJunchao Zhang 
180076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
181*0e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
182*0e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
183*0e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
184076ba34aSJunchao Zhang };
185076ba34aSJunchao Zhang 
186076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
187*0e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
188*0e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
189*0e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
190*0e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
191076ba34aSJunchao Zhang };
192076ba34aSJunchao Zhang 
1939371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1943ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1953ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1963ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
197*0e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
198076ba34aSJunchao Zhang 
199d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
200d71ae5a4SJacob Faibussowitsch   {
201076ba34aSJunchao Zhang     delete mmAB;
202076ba34aSJunchao Zhang     delete mmAtB;
203*0e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
204076ba34aSJunchao Zhang   }
205076ba34aSJunchao Zhang };
206076ba34aSJunchao Zhang 
207d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
208d71ae5a4SJacob Faibussowitsch {
209076ba34aSJunchao Zhang   PetscFunctionBegin;
2109566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
2113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
212076ba34aSJunchao Zhang }
213076ba34aSJunchao Zhang 
214076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
215076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
216076ba34aSJunchao Zhang 
217076ba34aSJunchao Zhang   Input Parameters:
218076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
219076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
220076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
221076ba34aSJunchao Zhang 
222076ba34aSJunchao Zhang   Output Parameters:
223076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
224076ba34aSJunchao Zhang */
225*0e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
226d71ae5a4SJacob Faibussowitsch {
227076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
228076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
229076ba34aSJunchao Zhang 
230076ba34aSJunchao Zhang   PetscFunctionBegin;
2319566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2329566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2339566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2349566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
235076ba34aSJunchao Zhang 
236aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
23708401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
238*0e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
23908401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
240076ba34aSJunchao Zhang   mpiaij->A      = A;
241076ba34aSJunchao Zhang   mpiaij->B      = B;
242*0e3ece09SJunchao Zhang   mpiaij->garray = garray;
243076ba34aSJunchao Zhang 
244076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
245076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
246076ba34aSJunchao Zhang 
2479566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2489566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
249076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
250076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
251076ba34aSJunchao Zhang   */
2529566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2539566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2549566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
256076ba34aSJunchao Zhang }
257076ba34aSJunchao Zhang 
258*0e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
259*0e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
260*0e3ece09SJunchao Zhang template <class ExecutionSpace>
261*0e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
262d71ae5a4SJacob Faibussowitsch {
263*0e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
264076ba34aSJunchao Zhang 
265076ba34aSJunchao Zhang   PetscFunctionBegin;
266*0e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
267076ba34aSJunchao Zhang 
268*0e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
269076ba34aSJunchao Zhang 
270*0e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
271076ba34aSJunchao Zhang 
272*0e3ece09SJunchao Zhang   if (vector_length < 1) {
273*0e3ece09SJunchao Zhang     vector_length = 1;
274*0e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
275076ba34aSJunchao Zhang   }
276076ba34aSJunchao Zhang 
277*0e3ece09SJunchao Zhang   // Determine rows per thread
278*0e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
279*0e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
280*0e3ece09SJunchao Zhang     else {
281*0e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
282*0e3ece09SJunchao Zhang         rows_per_thread = 256;
283*0e3ece09SJunchao Zhang       } else rows_per_thread = 64;
284076ba34aSJunchao Zhang     }
285076ba34aSJunchao Zhang   }
286076ba34aSJunchao Zhang 
287*0e3ece09SJunchao Zhang   if (team_size < 1) {
288*0e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
289*0e3ece09SJunchao Zhang       team_size = 256 / vector_length;
290076ba34aSJunchao Zhang     } else {
291*0e3ece09SJunchao Zhang       team_size = 1;
292*0e3ece09SJunchao Zhang     }
293076ba34aSJunchao Zhang   }
294076ba34aSJunchao Zhang 
295*0e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
296076ba34aSJunchao Zhang 
297*0e3ece09SJunchao Zhang   if (rows_per_team < 0) {
298*0e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
299*0e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
300*0e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
301*0e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
302*0e3ece09SJunchao Zhang   }
3033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
304076ba34aSJunchao Zhang }
305076ba34aSJunchao Zhang 
306*0e3ece09SJunchao Zhang /*
307*0e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
308076ba34aSJunchao Zhang 
309076ba34aSJunchao Zhang   Input Parameters:
310*0e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
311*0e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
312*0e3ece09SJunchao Zhang .  m           - size of indices[], the second set
313*0e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
314076ba34aSJunchao Zhang 
315076ba34aSJunchao Zhang   Output Parameters:
316*0e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
317*0e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
318*0e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
319*0e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
320076ba34aSJunchao Zhang 
321*0e3ece09SJunchao Zhang    Example, say
322*0e3ece09SJunchao Zhang     n1         = 5
323*0e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
324*0e3ece09SJunchao Zhang     m          = 4
325*0e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
32611a5261eSBarry Smith 
327*0e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
328*0e3ece09SJunchao Zhang     n2         = 7
329*0e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
330*0e3ece09SJunchao Zhang 
331*0e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
332*0e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
333*0e3ece09SJunchao Zhang 
334*0e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
335*0e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
336076ba34aSJunchao Zhang */
337*0e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
338d71ae5a4SJacob Faibussowitsch {
339*0e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
340*0e3ece09SJunchao Zhang   PetscHashIter iter;
341*0e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
342*0e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
343076ba34aSJunchao Zhang 
344076ba34aSJunchao Zhang   PetscFunctionBegin;
345*0e3ece09SJunchao Zhang   tot = 0;
346*0e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
347*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
348*0e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
349*0e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
350076ba34aSJunchao Zhang   }
351076ba34aSJunchao Zhang 
352*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
353*0e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
354*0e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
355076ba34aSJunchao Zhang   }
356076ba34aSJunchao Zhang 
357*0e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
358*0e3ece09SJunchao Zhang   n2 = tot;
359*0e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
360*0e3ece09SJunchao Zhang   tot = 0;
361*0e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
362*0e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
363*0e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
364*0e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
365*0e3ece09SJunchao Zhang     garray2[tot++] = key;
366076ba34aSJunchao Zhang   }
367076ba34aSJunchao Zhang 
368*0e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
369*0e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
370*0e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
371*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
372f0e6e2d1SJunchao Zhang 
373*0e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
374f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
375*0e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
376*0e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
377*0e3ece09SJunchao Zhang     indices[i] = val;
378*0e3ece09SJunchao Zhang   }
379*0e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
380*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
381*0e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
382*0e3ece09SJunchao Zhang   *n2_      = n2;
383*0e3ece09SJunchao Zhang   *garray2_ = garray2;
384*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
385*0e3ece09SJunchao Zhang }
386f0e6e2d1SJunchao Zhang 
387*0e3ece09SJunchao Zhang /*
388*0e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
389*0e3ece09SJunchao Zhang 
390*0e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
391*0e3ece09SJunchao Zhang 
392*0e3ece09SJunchao Zhang   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
393*0e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
394*0e3ece09SJunchao Zhang 
395*0e3ece09SJunchao Zhang   Input Parameters:
396*0e3ece09SJunchao Zhang +  comm       - MPI communicator of E
397*0e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
398*0e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
399*0e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
400*0e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
401*0e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
402*0e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
403*0e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
404*0e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
405*0e3ece09SJunchao Zhang 
406*0e3ece09SJunchao Zhang   Output Parameters:
407*0e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
408*0e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
409*0e3ece09SJunchao Zhang 
410*0e3ece09SJunchao Zhang   Notes:
411*0e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
412*0e3ece09SJunchao Zhang 
413*0e3ece09SJunchao Zhang  */
414*0e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
415*0e3ece09SJunchao Zhang {
416*0e3ece09SJunchao Zhang   PetscFunctionBegin;
417*0e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
418*0e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
419*0e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
420*0e3ece09SJunchao Zhang 
421*0e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
422*0e3ece09SJunchao Zhang 
423*0e3ece09SJunchao Zhang     // Do the analysis on host
424*0e3ece09SJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
425*0e3ece09SJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
426*0e3ece09SJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
427*0e3ece09SJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
428*0e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
429*0e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
430*0e3ece09SJunchao Zhang 
431*0e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
432*0e3ece09SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h("E_NzLeft_h", Em), E_RowLen_h("E_RowLen_h", Em);
433*0e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
434*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
435*0e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
436*0e3ece09SJunchao Zhang       PetscInt        count, step;
437*0e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
438*0e3ece09SJunchao Zhang       first = Bj + Bi[i];
439*0e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
440f0e6e2d1SJunchao Zhang       count = last - first;
441f0e6e2d1SJunchao Zhang       while (count > 0) {
442f0e6e2d1SJunchao Zhang         it   = first;
443f0e6e2d1SJunchao Zhang         step = count / 2;
444f0e6e2d1SJunchao Zhang         it += step;
445*0e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
446f0e6e2d1SJunchao Zhang           first = ++it;
447f0e6e2d1SJunchao Zhang           count -= step + 1;
448f0e6e2d1SJunchao Zhang         } else count = step;
449f0e6e2d1SJunchao Zhang       }
450*0e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
451*0e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
452f0e6e2d1SJunchao Zhang     }
453f0e6e2d1SJunchao Zhang 
454*0e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
455*0e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
456*0e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
457*0e3ece09SJunchao Zhang     PetscInt           niranks, nranks;
458*0e3ece09SJunchao Zhang     MPI_Request       *reqs;
459*0e3ece09SJunchao Zhang     PetscMPIInt        tag;
460*0e3ece09SJunchao Zhang     PetscSF            reduceSF;
461*0e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
462f0e6e2d1SJunchao Zhang 
463*0e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
464*0e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
465*0e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
466f0e6e2d1SJunchao Zhang 
467*0e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
468*0e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
469*0e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
470*0e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
471f0e6e2d1SJunchao Zhang 
472*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
473*0e3ece09SJunchao Zhang 
474*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
475*0e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
476*0e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
477*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
478*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
479*0e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
480*0e3ece09SJunchao Zhang 
481*0e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
482*0e3ece09SJunchao Zhang     rdisp[0] = 0;
483*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
484*0e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
485*0e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
486*0e3ece09SJunchao Zhang     }
487*0e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
488*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
489*0e3ece09SJunchao Zhang 
490*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
491*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
492*0e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
493*0e3ece09SJunchao Zhang 
494*0e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
495*0e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
496*0e3ece09SJunchao Zhang     PetscSFNode *iremote;
497*0e3ece09SJunchao Zhang 
498*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
499*0e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
500*0e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
501*0e3ece09SJunchao Zhang 
502*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
503*0e3ece09SJunchao Zhang       PetscInt count = 0;
504*0e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
505*0e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
506*0e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
507*0e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
508*0e3ece09SJunchao Zhang       }
509*0e3ece09SJunchao Zhang       nleaves += count;
510*0e3ece09SJunchao Zhang     }
511*0e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
512*0e3ece09SJunchao Zhang 
513*0e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
514*0e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
515*0e3ece09SJunchao Zhang 
516*0e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
517*0e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
518*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
519*0e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
520*0e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
521*0e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
522*0e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
523*0e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
524*0e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
525*0e3ece09SJunchao Zhang         if (j < nzLeft) {
526*0e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
527*0e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
528*0e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
529*0e3ece09SJunchao Zhang         } else {
530*0e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
531*0e3ece09SJunchao Zhang         }
532*0e3ece09SJunchao Zhang       }
533*0e3ece09SJunchao Zhang     }
534*0e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
535*0e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
536*0e3ece09SJunchao Zhang 
537*0e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
538*0e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
539*0e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
540*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
541*0e3ece09SJunchao Zhang 
542*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
543*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
544*0e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
545*0e3ece09SJunchao Zhang 
546*0e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
547*0e3ece09SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h("Fdi_h", Fm + 1), Foi_h("Foi_h", Fm + 1);
548*0e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
549*0e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
550*0e3ece09SJunchao Zhang     PetscInt                iter;
551*0e3ece09SJunchao Zhang 
552*0e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
553*0e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
554*0e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
555*0e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
556*0e3ece09SJunchao Zhang     iter = 0;
557*0e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
558*0e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
559*0e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
560*0e3ece09SJunchao Zhang 
561*0e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
562*0e3ece09SJunchao Zhang 
563*0e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
564*0e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
565*0e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
566*0e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
567*0e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
568*0e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
569*0e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
570*0e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
571*0e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
572*0e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
573*0e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
574*0e3ece09SJunchao Zhang         jbuf2 += len;
575*0e3ece09SJunchao Zhang         pbuf2 += len;
576*0e3ece09SJunchao Zhang         nz += len;
577*0e3ece09SJunchao Zhang       }
578*0e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
579*0e3ece09SJunchao Zhang 
580*0e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
581*0e3ece09SJunchao Zhang       PetscInt cur = 0;
582*0e3ece09SJunchao Zhang       while (cur < nz) {
583*0e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
584*0e3ece09SJunchao Zhang         PetscInt dups      = 1;
585*0e3ece09SJunchao Zhang 
586*0e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
587*0e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
588*0e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
589*0e3ece09SJunchao Zhang           FdnzDups += dups;
590*0e3ece09SJunchao Zhang         } else {
591*0e3ece09SJunchao Zhang           Foi[curRowIdx]++;
592*0e3ece09SJunchao Zhang           FonzDups += dups;
593*0e3ece09SJunchao Zhang         }
594*0e3ece09SJunchao Zhang         cur += dups;
595*0e3ece09SJunchao Zhang       }
596*0e3ece09SJunchao Zhang 
597*0e3ece09SJunchao Zhang       FnzDups += nz;
598*0e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
599*0e3ece09SJunchao Zhang     }
600*0e3ece09SJunchao Zhang 
601*0e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
602*0e3ece09SJunchao Zhang     Foi = Foi_h.data();
603*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
604*0e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
605*0e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
606*0e3ece09SJunchao Zhang     }
607*0e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
608*0e3ece09SJunchao Zhang     Fonz = Foi[Fm];
609*0e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
610*0e3ece09SJunchao Zhang 
611*0e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
612*0e3ece09SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h("Fdj_h", Fdnz), Foj_h("Foj_h", Fonz);
613*0e3ece09SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h("Fdjmap_h", Fdnz + 1), Fojmap_h("Fojmap_h", Fonz + 1); // +1 to make csr
614*0e3ece09SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h("Fdjperm_h", FdnzDups), Fojperm_h("Fojperm_h", FonzDups);
615*0e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
616*0e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
617*0e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
618*0e3ece09SJunchao Zhang 
619*0e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
620*0e3ece09SJunchao Zhang     Fdjmap[0] = 0;
621*0e3ece09SJunchao Zhang     Fojmap[0] = 0;
622*0e3ece09SJunchao Zhang     FnzDups   = 0;
623*0e3ece09SJunchao Zhang     Fdnz      = 0;
624*0e3ece09SJunchao Zhang     Fonz      = 0;
625*0e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
626*0e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
627*0e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
628*0e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
629*0e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
630*0e3ece09SJunchao Zhang 
631*0e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
632*0e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
633*0e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
634*0e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
635*0e3ece09SJunchao Zhang       }
636*0e3ece09SJunchao Zhang 
637*0e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
638*0e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
639*0e3ece09SJunchao Zhang       PetscInt cur = 0;
640*0e3ece09SJunchao Zhang       while (cur < nz) {
641*0e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
642*0e3ece09SJunchao Zhang         PetscInt dups      = 1;
643*0e3ece09SJunchao Zhang 
644*0e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
645*0e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
646*0e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
647*0e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
648*0e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
649*0e3ece09SJunchao Zhang           FdnzDups += dups;
650*0e3ece09SJunchao Zhang           Fdnz++;
651*0e3ece09SJunchao Zhang         } else {
652*0e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
653*0e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
654*0e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
655*0e3ece09SJunchao Zhang           FonzDups += dups;
656*0e3ece09SJunchao Zhang           Fonz++;
657*0e3ece09SJunchao Zhang         }
658*0e3ece09SJunchao Zhang         cur += dups;
659*0e3ece09SJunchao Zhang         FnzDups += dups;
660*0e3ece09SJunchao Zhang       }
661*0e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
662*0e3ece09SJunchao Zhang     }
663*0e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
664*0e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
665*0e3ece09SJunchao Zhang 
666*0e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
667*0e3ece09SJunchao Zhang     PetscInt n2, *garray2;
668*0e3ece09SJunchao Zhang 
669*0e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
670*0e3ece09SJunchao Zhang     mm->sf       = reduceSF;
671*0e3ece09SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView("leafBuf", nleaves);
672*0e3ece09SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView("rootBuf", nroots);
673*0e3ece09SJunchao Zhang     mm->garray   = garray2; // give owership, so no free
674*0e3ece09SJunchao Zhang     mm->n        = n2;
675*0e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
676*0e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
677*0e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
678*0e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
679*0e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
680*0e3ece09SJunchao Zhang 
681*0e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
682*0e3ece09SJunchao Zhang     MatScalarKokkosView Fda_d("Fda_d", Fdnz);
683*0e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
684*0e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
685*0e3ece09SJunchao Zhang     MatScalarKokkosView Foa_d("Foa_d", Fonz);
686*0e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
687*0e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
688*0e3ece09SJunchao Zhang 
689*0e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
690*0e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
691*0e3ece09SJunchao Zhang 
692*0e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
693*0e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
694*0e3ece09SJunchao Zhang 
695*0e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
696*0e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
697*0e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
698*0e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
699*0e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
700*0e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
701*0e3ece09SJunchao Zhang 
702*0e3ece09SJunchao Zhang   // Handy aliases
703*0e3ece09SJunchao Zhang   auto       &Aa           = A.values;
704*0e3ece09SJunchao Zhang   auto       &Ba           = B.values;
705*0e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
706*0e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
707*0e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
708*0e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
709*0e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
710*0e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
711*0e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
712*0e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
713*0e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
714*0e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
715*0e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
716*0e3ece09SJunchao Zhang 
717*0e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
718*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
719*0e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
720*0e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
721*0e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
722*0e3ece09SJunchao Zhang         if (i < Em) {
723*0e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
724*0e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
725*0e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
726*0e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
727*0e3ece09SJunchao Zhang 
728*0e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
729*0e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
730*0e3ece09SJunchao Zhang             if (j < nzleft) { // B left
731*0e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
732*0e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
733*0e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
734*0e3ece09SJunchao Zhang             } else { // B right
735*0e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
736f0e6e2d1SJunchao Zhang             }
737f0e6e2d1SJunchao Zhang           });
738f0e6e2d1SJunchao Zhang         }
739f0e6e2d1SJunchao Zhang       });
740*0e3ece09SJunchao Zhang     }));
741*0e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
742f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
743f0e6e2d1SJunchao Zhang }
744*0e3ece09SJunchao Zhang 
745*0e3ece09SJunchao Zhang // To finsih MatMPIAIJKokkosReduce.
746*0e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
747*0e3ece09SJunchao Zhang {
748*0e3ece09SJunchao Zhang   PetscFunctionBegin;
749*0e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
750*0e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
751*0e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
752*0e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
753*0e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
754*0e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
755*0e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
756*0e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
757*0e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
758*0e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
759*0e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
760*0e3ece09SJunchao Zhang 
761*0e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
762*0e3ece09SJunchao Zhang 
763*0e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
764*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
765*0e3ece09SJunchao Zhang     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
766*0e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
767*0e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
768*0e3ece09SJunchao Zhang       Fda(i) = sum;
769*0e3ece09SJunchao Zhang     }));
770*0e3ece09SJunchao Zhang 
771*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
772*0e3ece09SJunchao Zhang     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
773*0e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
774*0e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
775*0e3ece09SJunchao Zhang       Foa(i) = sum;
776*0e3ece09SJunchao Zhang     }));
777*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
778*0e3ece09SJunchao Zhang }
779*0e3ece09SJunchao Zhang 
780*0e3ece09SJunchao Zhang /*
781*0e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
782*0e3ece09SJunchao Zhang 
783*0e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
784*0e3ece09SJunchao Zhang   device execution and involves various index mappings.
785*0e3ece09SJunchao Zhang 
786*0e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
787*0e3ece09SJunchao Zhang   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i); this means we need to bcast the i-th row of E on rank k
788*0e3ece09SJunchao Zhang   to the j-th row of F. ownerSF is not an arbitrary SF; instead, it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
789*0e3ece09SJunchao Zhang   F has the same column layout as E.
790*0e3ece09SJunchao Zhang 
791*0e3ece09SJunchao Zhang   Conceptually, F has global column indices. In this routine, we split F into the diagonal block Fd and the off-diagonal block Fo.
792*0e3ece09SJunchao Zhang   Fd uses local column indices, which are easy to compute: we just need to subtract the "local column range start" from the global indices.
793*0e3ece09SJunchao Zhang   Fo initially has global column indices. We will reduce them into local ones. In doing so, we also take into account the global
794*0e3ece09SJunchao Zhang   column indices that E's off-diagonal block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
795*0e3ece09SJunchao Zhang   column indices in Fo and update Fo with local indices.
796*0e3ece09SJunchao Zhang 
797*0e3ece09SJunchao Zhang    Input Parameters:
798*0e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
799*0e3ece09SJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
800*0e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
801*0e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
802*0e3ece09SJunchao Zhang 
803*0e3ece09SJunchao Zhang     Output Parameters:
804*0e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
805*0e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
806*0e3ece09SJunchao Zhang 
807*0e3ece09SJunchao Zhang     Notes:
808*0e3ece09SJunchao Zhang     When reuse = MAT_REUSE_MATRIX, ownerSF and map are not significant.
809*0e3ece09SJunchao Zhang     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
810*0e3ece09SJunchao Zhang */
811*0e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
812*0e3ece09SJunchao Zhang {
813*0e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
814*0e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
815*0e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
816*0e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
817*0e3ece09SJunchao Zhang   MPI_Comm          comm;
818*0e3ece09SJunchao Zhang 
819*0e3ece09SJunchao Zhang   PetscFunctionBegin;
820*0e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
821*0e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
822*0e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
823*0e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
824*0e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
825*0e3ece09SJunchao Zhang     PetscInt        cstart, cend;
826*0e3ece09SJunchao Zhang     PetscSF         bcastSF;
827*0e3ece09SJunchao Zhang 
828*0e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
829*0e3ece09SJunchao Zhang 
830*0e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
831*0e3ece09SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h("E_NzLeft_h", Em), E_RowLen_h("E_RowLen_h", Em);
832*0e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
833*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
834*0e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
835*0e3ece09SJunchao Zhang       PetscInt        count, step;
836*0e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
837*0e3ece09SJunchao Zhang       first = Bj + Bi[i];
838*0e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
839*0e3ece09SJunchao Zhang       count = last - first;
840*0e3ece09SJunchao Zhang       while (count > 0) {
841*0e3ece09SJunchao Zhang         it   = first;
842*0e3ece09SJunchao Zhang         step = count / 2;
843*0e3ece09SJunchao Zhang         it += step;
844*0e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
845*0e3ece09SJunchao Zhang           first = ++it;
846*0e3ece09SJunchao Zhang           count -= step + 1;
847*0e3ece09SJunchao Zhang         } else count = step;
848*0e3ece09SJunchao Zhang       }
849*0e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
850*0e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
851*0e3ece09SJunchao Zhang     }
852*0e3ece09SJunchao Zhang 
853*0e3ece09SJunchao Zhang     // Compute row pointer Fi of F
854*0e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
855*0e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
856*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
857*0e3ece09SJunchao Zhang     Fi[0] = 0;
858*0e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
859*0e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
860*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
861*0e3ece09SJunchao Zhang     Fnz = Fi[Fm];
862*0e3ece09SJunchao Zhang 
863*0e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
864*0e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
865*0e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
866*0e3ece09SJunchao Zhang     PetscInt           niranks, nranks, *sdisp, *rdisp;
867*0e3ece09SJunchao Zhang     MPI_Request       *reqs;
868*0e3ece09SJunchao Zhang     PetscMPIInt        tag;
869*0e3ece09SJunchao Zhang 
870*0e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
871*0e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
872*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
873*0e3ece09SJunchao Zhang 
874*0e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
875*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
876*0e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
877*0e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
878*0e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
879*0e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
880*0e3ece09SJunchao Zhang       }
881*0e3ece09SJunchao Zhang     }
882*0e3ece09SJunchao Zhang 
883*0e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
884*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
885*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
886*0e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
887*0e3ece09SJunchao Zhang 
888*0e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
889*0e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
890*0e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
891*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
892*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
893*0e3ece09SJunchao Zhang       PetscInt k = 0;
894*0e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
895*0e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
896*0e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
897*0e3ece09SJunchao Zhang         k++;
898*0e3ece09SJunchao Zhang       }
899*0e3ece09SJunchao Zhang     }
900*0e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
901*0e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
902*0e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
903*0e3ece09SJunchao Zhang 
904*0e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
905*0e3ece09SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h("rowoffset_h", ioffset[niranks] + 1);
906*0e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
907*0e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
908*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] += rowoffset[i] + E_RowLen[irootloc[i]]; }
909*0e3ece09SJunchao Zhang 
910*0e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
911*0e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
912*0e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
913*0e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
914*0e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
915*0e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
916*0e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
917*0e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
918*0e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
919*0e3ece09SJunchao Zhang         if (j < nzLeft) {
920*0e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
921*0e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
922*0e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
923*0e3ece09SJunchao Zhang         } else {
924*0e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
925*0e3ece09SJunchao Zhang         }
926*0e3ece09SJunchao Zhang       }
927*0e3ece09SJunchao Zhang     }
928*0e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
929*0e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
930*0e3ece09SJunchao Zhang 
931*0e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
932*0e3ece09SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h("Fdi_h", Fm + 1), Foi_h("Foi_h", Fm + 1); // row pointer of Fd, Fo
933*0e3ece09SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h("F_NzLeft_h", Fm);                   // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
934*0e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
935*0e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
936*0e3ece09SJunchao Zhang 
937*0e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
938*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
939*0e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
940*0e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
941*0e3ece09SJunchao Zhang       first       = Fj + Fi[i];
942*0e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
943*0e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
944*0e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
945*0e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
946*0e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
947*0e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
948*0e3ece09SJunchao Zhang     }
949*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
950*0e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
951*0e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
952*0e3ece09SJunchao Zhang     }
953*0e3ece09SJunchao Zhang 
954*0e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
955*0e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
956*0e3ece09SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h("Fdj_h", Fdnz), Foj_h("Foj_h", Fonz);
957*0e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
958*0e3ece09SJunchao Zhang 
959*0e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
960*0e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
961*0e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
962*0e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
963*0e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
964*0e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
965*0e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
966*0e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
967*0e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
968*0e3ece09SJunchao Zhang         } else { // right, in global
969*0e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
970*0e3ece09SJunchao Zhang         }
971*0e3ece09SJunchao Zhang       }
972*0e3ece09SJunchao Zhang     }
973*0e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
974*0e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
975*0e3ece09SJunchao Zhang 
976*0e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
977*0e3ece09SJunchao Zhang     PetscInt n2, *garray2;
978*0e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
979*0e3ece09SJunchao Zhang 
980*0e3ece09SJunchao Zhang     // Record the plans built above, for reuse
981*0e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
982*0e3ece09SJunchao Zhang     PetscIntKokkosViewHost irootloc_h("irootloc_h", ioffset[niranks]);
983*0e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
984*0e3ece09SJunchao Zhang     mm->sf        = bcastSF;
985*0e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
986*0e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
987*0e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
988*0e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
989*0e3ece09SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView("rootBuf", nroots);
990*0e3ece09SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView("leafBuf", nleaves);
991*0e3ece09SJunchao Zhang     mm->garray    = garray2;
992*0e3ece09SJunchao Zhang     mm->n         = n2;
993*0e3ece09SJunchao Zhang 
994*0e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
995*0e3ece09SJunchao Zhang     MatScalarKokkosView Fda_d("Fda_d", Fdnz), Foa_d("Foa_d", Fonz);
996*0e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
997*0e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
998*0e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
999*0e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
1000*0e3ece09SJunchao Zhang 
1001*0e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1002*0e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1003*0e3ece09SJunchao Zhang 
1004*0e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
1005*0e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
1006*0e3ece09SJunchao Zhang 
1007*0e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
1008*0e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1009*0e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
1010*0e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
1011*0e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
1012*0e3ece09SJunchao Zhang 
1013*0e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
1014*0e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1015*0e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
1016*0e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
1017*0e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
1018*0e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1019*0e3ece09SJunchao Zhang 
1020*0e3ece09SJunchao Zhang   // Sync E's value to device
1021*0e3ece09SJunchao Zhang   akok->a_dual.sync_device();
1022*0e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
1023*0e3ece09SJunchao Zhang 
1024*0e3ece09SJunchao Zhang   // Handy aliases
1025*0e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
1026*0e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
1027*0e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
1028*0e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
1029*0e3ece09SJunchao Zhang 
1030*0e3ece09SJunchao Zhang   // Fetch the plans
1031*0e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1032*0e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
1033*0e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1034*0e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1035*0e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
1036*0e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
1037*0e3ece09SJunchao Zhang 
1038*0e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
1039*0e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
1040*0e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1041*0e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1042*0e3ece09SJunchao Zhang 
1043*0e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1044*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1045*0e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1046*0e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1047*0e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1048*0e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
1049*0e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
1050*0e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
1051*0e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
1052*0e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
1053*0e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
1054*0e3ece09SJunchao Zhang 
1055*0e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1056*0e3ece09SJunchao Zhang             if (j < nzleft) { // B left
1057*0e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
1058*0e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
1059*0e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1060*0e3ece09SJunchao Zhang             } else { // B right
1061*0e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1062*0e3ece09SJunchao Zhang             }
1063*0e3ece09SJunchao Zhang           });
1064*0e3ece09SJunchao Zhang         }
1065*0e3ece09SJunchao Zhang       });
1066*0e3ece09SJunchao Zhang     }));
1067*0e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1068*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1069*0e3ece09SJunchao Zhang }
1070*0e3ece09SJunchao Zhang 
1071*0e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
1072*0e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1073*0e3ece09SJunchao Zhang {
1074*0e3ece09SJunchao Zhang   PetscFunctionBegin;
1075*0e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
1076*0e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
1077*0e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
1078*0e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
1079*0e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
1080*0e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
1081*0e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
1082*0e3ece09SJunchao Zhang 
1083*0e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1084*0e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
1085*0e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1086*0e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1087*0e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
1088*0e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
1089*0e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1090*0e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1091*0e3ece09SJunchao Zhang 
1092*0e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1093*0e3ece09SJunchao Zhang 
1094*0e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1095*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1096*0e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1097*0e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1098*0e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1099*0e3ece09SJunchao Zhang         if (i < Fm) {
1100*0e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
1101*0e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1102*0e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
1103*0e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
1104*0e3ece09SJunchao Zhang 
1105*0e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1106*0e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
1107*0e3ece09SJunchao Zhang             if (j < nzLeft) { // left
1108*0e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
1109*0e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
1110*0e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
1111*0e3ece09SJunchao Zhang             } else { // right
1112*0e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
1113*0e3ece09SJunchao Zhang             }
1114*0e3ece09SJunchao Zhang           });
1115*0e3ece09SJunchao Zhang         }
1116*0e3ece09SJunchao Zhang       });
1117*0e3ece09SJunchao Zhang     }));
1118*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1119*0e3ece09SJunchao Zhang }
1120*0e3ece09SJunchao Zhang 
1121*0e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1122*0e3ece09SJunchao Zhang {
1123*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1124*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1125*0e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1126*0e3ece09SJunchao Zhang   PetscInt        cstart, cend;
1127*0e3ece09SJunchao Zhang   MPI_Comm        comm;
1128*0e3ece09SJunchao Zhang 
1129*0e3ece09SJunchao Zhang   PetscFunctionBegin;
1130*0e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1131*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1132*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1133*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1134*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1135*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1136*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1137*0e3ece09SJunchao Zhang 
1138*0e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
1139*0e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1140*0e3ece09SJunchao Zhang 
1141*0e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1142*0e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1143*0e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1144*0e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1145f0e6e2d1SJunchao Zhang   #endif
1146*0e3ece09SJunchao Zhang #endif
1147*0e3ece09SJunchao Zhang 
1148*0e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1149*0e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1150*0e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1151*0e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1152*0e3ece09SJunchao Zhang 
1153*0e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
1154*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1155*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1156*0e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1157*0e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1158*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1159*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1160*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1161*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
1162*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
1163*0e3ece09SJunchao Zhang #endif
1164*0e3ece09SJunchao Zhang 
1165*0e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1166*0e3ece09SJunchao Zhang   PetscIntKokkosViewHost map_h("map_h", bmpi->B->cmap->n);
1167*0e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1168*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1169*0e3ece09SJunchao Zhang 
1170*0e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
1171*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1172*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1173*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1174*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1175*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1176*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
1177*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1178*0e3ece09SJunchao Zhang #endif
1179*0e3ece09SJunchao Zhang 
1180*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1181*0e3ece09SJunchao Zhang 
1182*0e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1183*0e3ece09SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj("j", oldj.extent(0));
1184*0e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1185*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1186*0e3ece09SJunchao Zhang     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1187*0e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1188*0e3ece09SJunchao Zhang 
1189*0e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
1190*0e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1191*0e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1192*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1193*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1194*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1195*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1196*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1197*0e3ece09SJunchao Zhang }
1198*0e3ece09SJunchao Zhang 
1199*0e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1200*0e3ece09SJunchao Zhang {
1201*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1202*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1203*0e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1204*0e3ece09SJunchao Zhang   MPI_Comm        comm;
1205*0e3ece09SJunchao Zhang 
1206*0e3ece09SJunchao Zhang   PetscFunctionBegin;
1207*0e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1208*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1209*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1210*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1211*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1212*0e3ece09SJunchao Zhang 
1213*0e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
1214*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1215*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1216*0e3ece09SJunchao Zhang 
1217*0e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1218*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1219*0e3ece09SJunchao Zhang 
1220*0e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
1221*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1222*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1223*0e3ece09SJunchao Zhang 
1224*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1225*0e3ece09SJunchao Zhang 
1226*0e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
1227*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1228*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1229*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1230*0e3ece09SJunchao Zhang }
1231f0e6e2d1SJunchao Zhang 
1232076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1233076ba34aSJunchao Zhang 
1234076ba34aSJunchao Zhang   Input Parameters:
1235076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1236076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1237076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1238076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1239076ba34aSJunchao Zhang */
1240d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241d71ae5a4SJacob Faibussowitsch {
1242*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1243*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1244*0e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245076ba34aSJunchao Zhang 
1246076ba34aSJunchao Zhang   PetscFunctionBegin;
1247*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1248*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1249*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1250*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251*0e3ece09SJunchao Zhang 
1252*0e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
1253*0e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1254*0e3ece09SJunchao Zhang 
1255*0e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1256*0e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1257*0e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1258*0e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1259*0e3ece09SJunchao Zhang   #endif
1260f0e6e2d1SJunchao Zhang #endif
1261f0e6e2d1SJunchao Zhang 
1262*0e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
1263*0e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
1264*0e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
1265*0e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1266076ba34aSJunchao Zhang 
1267*0e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
1268*0e3ece09SJunchao Zhang   PetscIntKokkosViewHost map_h("map_h", bmpi->B->cmap->n);
1269*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1270076ba34aSJunchao Zhang 
1271*0e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
1272*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1273*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1274*0e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1275*0e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
1276*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1277*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1278*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1279*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
1280*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1281*0e3ece09SJunchao Zhang #endif
1282076ba34aSJunchao Zhang 
1283*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1284076ba34aSJunchao Zhang 
1285*0e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
1286*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1287*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1288*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1289*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1290*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1291*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
1292*0e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
1293*0e3ece09SJunchao Zhang #endif
1294076ba34aSJunchao Zhang 
1295*0e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1296*0e3ece09SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj("j", oldj.extent(0));
1297*0e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1298*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1299*0e3ece09SJunchao Zhang     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1300*0e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1301*0e3ece09SJunchao Zhang 
1302*0e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
1303*0e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1304*0e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1305*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1306*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1307*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1308*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13093ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1310076ba34aSJunchao Zhang }
1311076ba34aSJunchao Zhang 
1312*0e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1313d71ae5a4SJacob Faibussowitsch {
1314*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1315*0e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1316*0e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1317076ba34aSJunchao Zhang 
1318076ba34aSJunchao Zhang   PetscFunctionBegin;
1319*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1320*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1321*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1322*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1323076ba34aSJunchao Zhang 
1324*0e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
1325*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1326076ba34aSJunchao Zhang 
1327*0e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
1328*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1329*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1330076ba34aSJunchao Zhang 
1331*0e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1332076ba34aSJunchao Zhang 
1333*0e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
1334*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1335*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1336*0e3ece09SJunchao Zhang 
1337*0e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
1338*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1339*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13403ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1341076ba34aSJunchao Zhang }
1342076ba34aSJunchao Zhang 
1343d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1344d71ae5a4SJacob Faibussowitsch {
1345*0e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1346*0e3ece09SJunchao Zhang   Mat_Product                 *product;
1347*0e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1348076ba34aSJunchao Zhang   MatProductType               ptype;
1349*0e3ece09SJunchao Zhang   Mat                          A, B;
1350076ba34aSJunchao Zhang 
1351076ba34aSJunchao Zhang   PetscFunctionBegin;
1352*0e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
1353*0e3ece09SJunchao Zhang   product = C->product;
1354*0e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1355076ba34aSJunchao Zhang   ptype   = product->type;
1356076ba34aSJunchao Zhang   A       = product->A;
1357076ba34aSJunchao Zhang   B       = product->B;
1358076ba34aSJunchao Zhang 
1359*0e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1360*0e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1361*0e3ece09SJunchao Zhang   // we still do numeric.
1362*0e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
1363*0e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13643ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1365076ba34aSJunchao Zhang   }
1366076ba34aSJunchao Zhang 
1367076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1368*0e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1369076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1370*0e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1371*0e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
1372*0e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1373*0e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1374076ba34aSJunchao Zhang   }
1375*0e3ece09SJunchao Zhang 
1376*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
1377*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1379076ba34aSJunchao Zhang }
1380076ba34aSJunchao Zhang 
1381d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1382d71ae5a4SJacob Faibussowitsch {
1383076ba34aSJunchao Zhang   Mat                          A, B;
1384*0e3ece09SJunchao Zhang   Mat_Product                 *product;
1385076ba34aSJunchao Zhang   MatProductType               ptype;
1386*0e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1387076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
1388*0e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
1389*0e3ece09SJunchao Zhang   Mat                          Cd, Co;
1390*0e3ece09SJunchao Zhang   MPI_Comm                     comm;
1391076ba34aSJunchao Zhang 
1392076ba34aSJunchao Zhang   PetscFunctionBegin;
1393*0e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1394076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
1395*0e3ece09SJunchao Zhang   product = C->product;
1396*0e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1397076ba34aSJunchao Zhang   ptype = product->type;
1398076ba34aSJunchao Zhang   A     = product->A;
1399076ba34aSJunchao Zhang   B     = product->B;
1400076ba34aSJunchao Zhang 
1401076ba34aSJunchao Zhang   switch (ptype) {
14029371c9d4SSatish Balay   case MATPRODUCT_AB:
14039371c9d4SSatish Balay     m = A->rmap->n;
14049371c9d4SSatish Balay     n = B->cmap->n;
14059371c9d4SSatish Balay     M = A->rmap->N;
14069371c9d4SSatish Balay     N = B->cmap->N;
14079371c9d4SSatish Balay     break;
14089371c9d4SSatish Balay   case MATPRODUCT_AtB:
14099371c9d4SSatish Balay     m = A->cmap->n;
14109371c9d4SSatish Balay     n = B->cmap->n;
14119371c9d4SSatish Balay     M = A->cmap->N;
14129371c9d4SSatish Balay     N = B->cmap->N;
14139371c9d4SSatish Balay     break;
14149371c9d4SSatish Balay   case MATPRODUCT_PtAP:
14159371c9d4SSatish Balay     m = B->cmap->n;
14169371c9d4SSatish Balay     n = B->cmap->n;
14179371c9d4SSatish Balay     M = B->cmap->N;
14189371c9d4SSatish Balay     N = B->cmap->N;
14199371c9d4SSatish Balay     break; /* BtAB */
1420d71ae5a4SJacob Faibussowitsch   default:
1421*0e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1422076ba34aSJunchao Zhang   }
1423076ba34aSJunchao Zhang 
14249566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
14259566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
14269566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
14279566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1428076ba34aSJunchao Zhang 
1429*0e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
1430*0e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1431076ba34aSJunchao Zhang 
1432076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1433*0e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
1434*0e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1435*0e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1436076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1437*0e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
1438*0e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1439*0e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1440*0e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
1441*0e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
1442*0e3ece09SJunchao Zhang 
1443*0e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
1444*0e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1445*0e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1446*0e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1447*0e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
1448*0e3ece09SJunchao Zhang 
1449*0e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
1450*0e3ece09SJunchao Zhang     n = B->cmap->n;
1451*0e3ece09SJunchao Zhang     M = A->rmap->N;
1452*0e3ece09SJunchao Zhang     N = B->cmap->N;
1453*0e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
1454*0e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
1455*0e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
1456*0e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
1457*0e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1458*0e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1459*0e3ece09SJunchao Zhang 
1460*0e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
1461*0e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1462*0e3ece09SJunchao Zhang 
1463*0e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
1464*0e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1465076ba34aSJunchao Zhang   }
1466*0e3ece09SJunchao Zhang 
1467*0e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1468*0e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1469*0e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1470*0e3ece09SJunchao Zhang 
1471*0e3ece09SJunchao Zhang   C->product->data       = pdata;
1472076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1473076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14743ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1475076ba34aSJunchao Zhang }
1476076ba34aSJunchao Zhang 
1477d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1478d71ae5a4SJacob Faibussowitsch {
1479076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1480076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1481076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1482076ba34aSJunchao Zhang 
1483076ba34aSJunchao Zhang   PetscFunctionBegin;
1484076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
148548a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1486076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1487076ba34aSJunchao Zhang     switch (product->type) {
1488076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1489076ba34aSJunchao Zhang       if (product->api_user) {
1490d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14919566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1492d0609cedSBarry Smith         PetscOptionsEnd();
1493076ba34aSJunchao Zhang       } else {
1494d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14959566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1496d0609cedSBarry Smith         PetscOptionsEnd();
1497076ba34aSJunchao Zhang       }
1498076ba34aSJunchao Zhang       break;
1499076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1500076ba34aSJunchao Zhang       if (product->api_user) {
1501d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
15029566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1503d0609cedSBarry Smith         PetscOptionsEnd();
1504076ba34aSJunchao Zhang       } else {
1505d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
15069566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1507d0609cedSBarry Smith         PetscOptionsEnd();
1508076ba34aSJunchao Zhang       }
1509076ba34aSJunchao Zhang       break;
1510076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1511076ba34aSJunchao Zhang       if (product->api_user) {
1512d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
15139566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1514d0609cedSBarry Smith         PetscOptionsEnd();
1515076ba34aSJunchao Zhang       } else {
1516d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
15179566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1518d0609cedSBarry Smith         PetscOptionsEnd();
1519076ba34aSJunchao Zhang       }
1520076ba34aSJunchao Zhang       break;
1521d71ae5a4SJacob Faibussowitsch     default:
1522d71ae5a4SJacob Faibussowitsch       break;
1523076ba34aSJunchao Zhang     }
1524076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1525076ba34aSJunchao Zhang   }
1526076ba34aSJunchao Zhang   if (match) {
1527076ba34aSJunchao Zhang     switch (product->type) {
1528076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1529076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1530d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1531d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1532d71ae5a4SJacob Faibussowitsch       break;
1533d71ae5a4SJacob Faibussowitsch     default:
1534d71ae5a4SJacob Faibussowitsch       break;
1535076ba34aSJunchao Zhang     }
1536076ba34aSJunchao Zhang   }
1537076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
153848a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
15393ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1540076ba34aSJunchao Zhang }
1541076ba34aSJunchao Zhang 
1542d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1543d71ae5a4SJacob Faibussowitsch {
1544394ed5ebSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1545cbc6b225SStefano Zampini   Mat_MPIAIJKokkos *mpikok;
154642550becSJunchao Zhang 
154742550becSJunchao Zhang   PetscFunctionBegin;
154830203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1549cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15509566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15519566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15529566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
1553cbc6b225SStefano Zampini   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1554cbc6b225SStefano Zampini   delete mpikok;
1555394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
15563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
155742550becSJunchao Zhang }
155842550becSJunchao Zhang 
1559d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1560d71ae5a4SJacob Faibussowitsch {
1561394ed5ebSJunchao Zhang   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
156242550becSJunchao Zhang   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
156342550becSJunchao Zhang   Mat                         A = mpiaij->A, B = mpiaij->B;
1564158ec288SJunchao Zhang   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
156542550becSJunchao Zhang   MatScalarKokkosView         Aa, Ba;
1566394ed5ebSJunchao Zhang   MatScalarKokkosView         v1;
156742550becSJunchao Zhang   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
156842550becSJunchao Zhang   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1569158ec288SJunchao Zhang   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1570158ec288SJunchao Zhang   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1571394ed5ebSJunchao Zhang   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1572394ed5ebSJunchao Zhang   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
157342550becSJunchao Zhang   PetscMemType                memtype;
157442550becSJunchao Zhang 
157542550becSJunchao Zhang   PetscFunctionBegin;
15769566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
157742550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1578394ed5ebSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
157942550becSJunchao Zhang   } else {
1580394ed5ebSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
158142550becSJunchao Zhang   }
158242550becSJunchao Zhang 
158342550becSJunchao Zhang   if (imode == INSERT_VALUES) {
15849566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
15859566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1586394ed5ebSJunchao Zhang   } else {
15879566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
15889566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
158942550becSJunchao Zhang   }
159042550becSJunchao Zhang 
159142550becSJunchao Zhang   /* Pack entries to be sent to remote */
15929371c9d4SSatish Balay   Kokkos::parallel_for(
15939371c9d4SSatish Balay     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
159442550becSJunchao Zhang 
159542550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
15969566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1597158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
15989371c9d4SSatish Balay   Kokkos::parallel_for(
15999371c9d4SSatish Balay     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1600158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1601158ec288SJunchao Zhang       if (i < Annz) {
1602158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1603ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1604158ec288SJunchao Zhang       } else {
1605158ec288SJunchao Zhang         i -= Annz;
1606158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1607ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1608158ec288SJunchao Zhang       }
1609158ec288SJunchao Zhang     });
16109566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
161142550becSJunchao Zhang 
1612158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16139371c9d4SSatish Balay   Kokkos::parallel_for(
16149371c9d4SSatish Balay     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1615158ec288SJunchao Zhang       if (i < Annz2) {
1616158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1617158ec288SJunchao Zhang       } else {
1618158ec288SJunchao Zhang         i -= Annz2;
1619158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1620158ec288SJunchao Zhang       }
1621158ec288SJunchao Zhang     });
162242550becSJunchao Zhang 
1623394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
16249566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
16259566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1626394ed5ebSJunchao Zhang   } else {
16279566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
16289566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1629394ed5ebSJunchao Zhang   }
16303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
163142550becSJunchao Zhang }
163242550becSJunchao Zhang 
1633d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1634d71ae5a4SJacob Faibussowitsch {
163542550becSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1636076ba34aSJunchao Zhang 
1637076ba34aSJunchao Zhang   PetscFunctionBegin;
16389566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
16399566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
16409566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
16419566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
164242550becSJunchao Zhang   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
16439566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
16443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1645076ba34aSJunchao Zhang }
1646076ba34aSJunchao Zhang 
1647d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1648d71ae5a4SJacob Faibussowitsch {
16498c3ff71bSJunchao Zhang   Mat         B;
1650076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
16518c3ff71bSJunchao Zhang 
16528c3ff71bSJunchao Zhang   PetscFunctionBegin;
16538c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
16549566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
16558c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
16569566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
16578c3ff71bSJunchao Zhang   }
16588c3ff71bSJunchao Zhang   B = *newmat;
16598c3ff71bSJunchao Zhang 
16606f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
16619566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
16629566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
16639566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
16648c3ff71bSJunchao Zhang 
1665076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
16669566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
16679566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
16689566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1669076ba34aSJunchao Zhang 
16708c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
16718c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
16728c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
16738c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1674076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1675076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
16768c3ff71bSJunchao Zhang 
16779566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
16789566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
16799566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
16809566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
16813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16828c3ff71bSJunchao Zhang }
16833f3ba80aSJunchao Zhang /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
16858c3ff71bSJunchao Zhang 
   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
16873f3ba80aSJunchao Zhang 
16882ef1f0ffSBarry Smith    Options Database Key:
16892ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
16903f3ba80aSJunchao Zhang 
16913f3ba80aSJunchao Zhang   Level: beginner
16923f3ba80aSJunchao Zhang 
16932ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
16943f3ba80aSJunchao Zhang M*/
1695d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1696d71ae5a4SJacob Faibussowitsch {
16978c3ff71bSJunchao Zhang   PetscFunctionBegin;
16989566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16999566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
17009566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
17013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17028c3ff71bSJunchao Zhang }
17038c3ff71bSJunchao Zhang 
17048c3ff71bSJunchao Zhang /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
17088c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
17092ef1f0ffSBarry Smith    the parameter `nz` (or the array `nnz`).
17108c3ff71bSJunchao Zhang 
17118c3ff71bSJunchao Zhang    Collective
17128c3ff71bSJunchao Zhang 
17138c3ff71bSJunchao Zhang    Input Parameters:
171411a5261eSBarry Smith +  comm - MPI communicator, set to `PETSC_COMM_SELF`
17158c3ff71bSJunchao Zhang .  m - number of rows
17168c3ff71bSJunchao Zhang .  n - number of columns
17178c3ff71bSJunchao Zhang .  nz - number of nonzeros per row (same for all rows)
17188c3ff71bSJunchao Zhang -  nnz - array containing the number of nonzeros in the various rows
17192ef1f0ffSBarry Smith          (possibly different for each row) or `NULL`
17208c3ff71bSJunchao Zhang 
17218c3ff71bSJunchao Zhang    Output Parameter:
17228c3ff71bSJunchao Zhang .  A - the matrix
17238c3ff71bSJunchao Zhang 
17242ef1f0ffSBarry Smith    Level: intermediate
17252ef1f0ffSBarry Smith 
17262ef1f0ffSBarry Smith    Notes:
172711a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
17288c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
172911a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
17308c3ff71bSJunchao Zhang 
17312ef1f0ffSBarry Smith    If `nnz` is given then `nz` is ignored
17328c3ff71bSJunchao Zhang 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
17348c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
17352ef1f0ffSBarry Smith    either one (as in Fortran) or zero.
17368c3ff71bSJunchao Zhang 
17378c3ff71bSJunchao Zhang    Specify the preallocated storage with either nz or nnz (not both).
17382ef1f0ffSBarry Smith    Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
17398c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
17408c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
17418c3ff71bSJunchao Zhang 
17428c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
17438c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
17448c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
17458c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
17468c3ff71bSJunchao Zhang 
17472ef1f0ffSBarry Smith    Developer Note:
17482ef1f0ffSBarry Smith    This manual page is for the sequential constructor, not the parallel constructor
17498c3ff71bSJunchao Zhang 
.seealso: [](chapter_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
17512ef1f0ffSBarry Smith           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
17528c3ff71bSJunchao Zhang @*/
1753d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1754d71ae5a4SJacob Faibussowitsch {
17558c3ff71bSJunchao Zhang   PetscMPIInt size;
17568c3ff71bSJunchao Zhang 
17578c3ff71bSJunchao Zhang   PetscFunctionBegin;
17589566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
17599566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
17609566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
17618c3ff71bSJunchao Zhang   if (size > 1) {
17629566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
17639566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
17648c3ff71bSJunchao Zhang   } else {
17659566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
17669566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
17678c3ff71bSJunchao Zhang   }
17683ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17698c3ff71bSJunchao Zhang }
17708c3ff71bSJunchao Zhang 
1771a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1772d71ae5a4SJacob Faibussowitsch PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1773d71ae5a4SJacob Faibussowitsch {
1774a587d139SMark   PetscMPIInt                size, rank;
1775a587d139SMark   MPI_Comm                   comm;
1776042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat = NULL;
1777a587d139SMark 
1778a587d139SMark   PetscFunctionBegin;
17799566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
17809566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
17819566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1782a587d139SMark   if (size == 1) {
17839566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
17849566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1785a587d139SMark   } else {
1786a587d139SMark     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
17879566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
17889566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
17899566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
17902c71b3e2SJacob Faibussowitsch     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1791a587d139SMark   }
1792a587d139SMark   // act like MatSetValues because not called on host
1793a587d139SMark   if (A->assembled) {
179448a46eb9SPierre Jolivet     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1795a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1796a587d139SMark   } else {
17979566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1798a587d139SMark   }
1799a587d139SMark   if (!d_mat) {
1800042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1801a587d139SMark     Mat_SeqAIJKokkos     *aijkokA;
1802a587d139SMark     Mat_SeqAIJ           *jaca;
1803a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1804a587d139SMark     Mat                   Amat;
1805042217e8SBarry Smith     PetscInt             *colmap;
1806042217e8SBarry Smith 
1807042217e8SBarry Smith     /* create and copy h_mat */
180849b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
18099566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1810a587d139SMark     if (size == 1) {
1811a587d139SMark       Amat            = A;
1812a587d139SMark       jaca            = (Mat_SeqAIJ *)A->data;
18139371c9d4SSatish Balay       h_mat.rstart    = 0;
18149371c9d4SSatish Balay       h_mat.rend      = A->rmap->n;
18159371c9d4SSatish Balay       h_mat.cstart    = 0;
18169371c9d4SSatish Balay       h_mat.cend      = A->cmap->n;
1817a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1818a587d139SMark       h_mat.offdiag.a                   = NULL;
1819a587d139SMark       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1820a587d139SMark     } else {
1821a587d139SMark       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1822a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1823a587d139SMark       PetscInt          ii;
1824a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1825042217e8SBarry Smith 
1826a587d139SMark       Amat    = aij->A;
1827a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1828a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1829a587d139SMark       jaca    = (Mat_SeqAIJ *)aij->A->data;
183008401ef6SPierre Jolivet       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
183108401ef6SPierre Jolivet       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1832a587d139SMark       aij->donotstash          = PETSC_TRUE;
1833a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1834a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
18359566063dSJacob Faibussowitsch       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1836042217e8SBarry Smith       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1837a587d139SMark       // allocate B copy data
18389371c9d4SSatish Balay       h_mat.rstart = A->rmap->rstart;
18399371c9d4SSatish Balay       h_mat.rend   = A->rmap->rend;
18409371c9d4SSatish Balay       h_mat.cstart = A->cmap->rstart;
18419371c9d4SSatish Balay       h_mat.cend   = A->cmap->rend;
1842a587d139SMark       nnz          = jacb->i[n];
1843a587d139SMark       if (jacb->compressedrow.use) {
1844a587d139SMark         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1845300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1846300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1847300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1848a587d139SMark       } else {
184999551766SMark Adams         h_mat.offdiag.i = aijkokB->i_device_data();
1850a587d139SMark       }
185199551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1852076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1853a587d139SMark       {
1854042217e8SBarry Smith         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1855300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1856300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1857300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
18589566063dSJacob Faibussowitsch         PetscCall(PetscFree(colmap));
1859a587d139SMark       }
1860a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1861a587d139SMark       h_mat.offdiag.n                 = n;
1862a587d139SMark     }
1863a587d139SMark     // allocate A copy data
1864a587d139SMark     nnz                          = jaca->i[n];
1865a587d139SMark     h_mat.diag.n                 = n;
1866a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
18679566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1868d5b43468SJose E. Roman     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
186999551766SMark Adams     h_mat.diag.i = aijkokA->i_device_data();
187099551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1871076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1872da81f932SPierre Jolivet     // copy pointers and metadata to device
18739566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
18749566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
18759566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1876a587d139SMark   }
1877a587d139SMark   *B           = d_mat;       // return it, set it in Mat, and set it up
1878a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
18793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1880a587d139SMark }
1881076ba34aSJunchao Zhang 
1882d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1883d71ae5a4SJacob Faibussowitsch {
1884076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1885076ba34aSJunchao Zhang 
1886076ba34aSJunchao Zhang   PetscFunctionBegin;
1887076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1888076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1889076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1890076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
18913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1892076ba34aSJunchao Zhang }
1893076ba34aSJunchao Zhang 
1894d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1895d71ae5a4SJacob Faibussowitsch {
1896076ba34aSJunchao Zhang   PetscMPIInt size;
1897076ba34aSJunchao Zhang   Mat         Ad, Ao;
1898076ba34aSJunchao Zhang   const char *amask, *bmask;
1899076ba34aSJunchao Zhang 
1900076ba34aSJunchao Zhang   PetscFunctionBegin;
19019566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1902076ba34aSJunchao Zhang 
1903076ba34aSJunchao Zhang   if (size == 1) {
19049566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
19059566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1906076ba34aSJunchao Zhang   } else {
1907076ba34aSJunchao Zhang     Ad = ((Mat_MPIAIJ *)A->data)->A;
1908076ba34aSJunchao Zhang     Ao = ((Mat_MPIAIJ *)A->data)->B;
19099566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
19109566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
19119566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1912076ba34aSJunchao Zhang   }
19133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1914076ba34aSJunchao Zhang }
1915