xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 08bb99260af10b9a29bd93948fe63730a22f8379)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
342550becSJunchao Zhang #include <petsc/private/sfimpl.h>
42c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
58c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
6076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
70e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
811d22bbfSJunchao Zhang 
9d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
10d71ae5a4SJacob Faibussowitsch {
1130203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
128c3ff71bSJunchao Zhang 
138c3ff71bSJunchao Zhang   PetscFunctionBegin;
149566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1530203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1630203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1730203840SJunchao Zhang    */
1830203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
1930203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2030203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2130203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2230203840SJunchao Zhang   }
23a587d139SMark 
243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
258c3ff71bSJunchao Zhang }
268c3ff71bSJunchao Zhang 
27d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
28d71ae5a4SJacob Faibussowitsch {
298c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
308c3ff71bSJunchao Zhang 
318c3ff71bSJunchao Zhang   PetscFunctionBegin;
329566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
339566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
346a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
358c3ff71bSJunchao Zhang   if (d_nnz) {
366a29ce69SStefano Zampini     PetscInt i;
37ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
388c3ff71bSJunchao Zhang   }
398c3ff71bSJunchao Zhang   if (o_nnz) {
406a29ce69SStefano Zampini     PetscInt i;
41ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
428c3ff71bSJunchao Zhang   }
436a29ce69SStefano Zampini #endif
446a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
45eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
466a29ce69SStefano Zampini #else
479566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
486a29ce69SStefano Zampini #endif
499566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
509566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
519566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
526a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
539566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
546a29ce69SStefano Zampini 
556a29ce69SStefano Zampini   if (!mpiaij->A) {
569566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
579566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
586a29ce69SStefano Zampini   }
596a29ce69SStefano Zampini   if (!mpiaij->B) {
606a29ce69SStefano Zampini     PetscMPIInt size;
619566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
629566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
639566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
648c3ff71bSJunchao Zhang   }
659566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
669566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
679566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
698c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
703ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
718c3ff71bSJunchao Zhang }
728c3ff71bSJunchao Zhang 
73d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
74d71ae5a4SJacob Faibussowitsch {
758c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
768c3ff71bSJunchao Zhang   PetscInt    nt;
778c3ff71bSJunchao Zhang 
788c3ff71bSJunchao Zhang   PetscFunctionBegin;
799566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8008401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
819566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
829566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
839566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
849566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
868c3ff71bSJunchao Zhang }
878c3ff71bSJunchao Zhang 
88d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
89d71ae5a4SJacob Faibussowitsch {
908c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
918c3ff71bSJunchao Zhang   PetscInt    nt;
928c3ff71bSJunchao Zhang 
938c3ff71bSJunchao Zhang   PetscFunctionBegin;
949566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
9508401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
969566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
979566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
989566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
999566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
1003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1018c3ff71bSJunchao Zhang }
1028c3ff71bSJunchao Zhang 
103d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
104d71ae5a4SJacob Faibussowitsch {
1058c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1068c3ff71bSJunchao Zhang   PetscInt    nt;
1078c3ff71bSJunchao Zhang 
1088c3ff71bSJunchao Zhang   PetscFunctionBegin;
1099566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
11008401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1119566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1129566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1139566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1149566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1153ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1168c3ff71bSJunchao Zhang }
1178c3ff71bSJunchao Zhang 
118076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
119076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
120076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
121076ba34aSJunchao Zhang */
122d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
123d71ae5a4SJacob Faibussowitsch {
124076ba34aSJunchao Zhang   Mat             Ad, Ao;
125076ba34aSJunchao Zhang   const PetscInt *cmap;
126076ba34aSJunchao Zhang 
127076ba34aSJunchao Zhang   PetscFunctionBegin;
1289566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
130076ba34aSJunchao Zhang   if (glob) {
131076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1329566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1339566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1349566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1359566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
136076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
137076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1389566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
139076ba34aSJunchao Zhang   }
1403ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
141076ba34aSJunchao Zhang }
142076ba34aSJunchao Zhang 
1430e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
144076ba34aSJunchao Zhang struct MatMatStruct {
1450e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1460e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1470e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1480e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1490e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1500e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1510e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1520e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1530e3ece09SJunchao Zhang 
1540e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1550e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1560e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1570e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1580e3ece09SJunchao Zhang 
159aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1600e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1610e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1620e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1630e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1640e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
165076ba34aSJunchao Zhang 
166d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
167d71ae5a4SJacob Faibussowitsch   {
1683ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1693ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1703ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
171076ba34aSJunchao Zhang   }
172076ba34aSJunchao Zhang };
173076ba34aSJunchao Zhang 
174076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1750e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1760e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1770e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
178076ba34aSJunchao Zhang };
179076ba34aSJunchao Zhang 
180076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1810e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1820e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1830e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1840e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
185076ba34aSJunchao Zhang };
186076ba34aSJunchao Zhang 
1879371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1883ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1893ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1903ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
1910e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
192076ba34aSJunchao Zhang 
193d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
194d71ae5a4SJacob Faibussowitsch   {
195076ba34aSJunchao Zhang     delete mmAB;
196076ba34aSJunchao Zhang     delete mmAtB;
1970e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
198076ba34aSJunchao Zhang   }
199076ba34aSJunchao Zhang };
200076ba34aSJunchao Zhang 
201d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202d71ae5a4SJacob Faibussowitsch {
203076ba34aSJunchao Zhang   PetscFunctionBegin;
2049566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
2053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
206076ba34aSJunchao Zhang }
207076ba34aSJunchao Zhang 
208076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
209076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
210076ba34aSJunchao Zhang 
211076ba34aSJunchao Zhang   Input Parameters:
212076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
213076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
214076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
215076ba34aSJunchao Zhang 
2162fe279fdSBarry Smith   Output Parameter:
217076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
218076ba34aSJunchao Zhang */
2190e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
220d71ae5a4SJacob Faibussowitsch {
221076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
222076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
223076ba34aSJunchao Zhang 
224076ba34aSJunchao Zhang   PetscFunctionBegin;
2259566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2269566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2279566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2289566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
229076ba34aSJunchao Zhang 
230aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
23108401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
2320e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
23308401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
234076ba34aSJunchao Zhang   mpiaij->A      = A;
235076ba34aSJunchao Zhang   mpiaij->B      = B;
2360e3ece09SJunchao Zhang   mpiaij->garray = garray;
237076ba34aSJunchao Zhang 
238076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
239076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
240076ba34aSJunchao Zhang 
2419566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2429566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
243076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
244076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
245076ba34aSJunchao Zhang   */
2469566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2479566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2489566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2493ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
250076ba34aSJunchao Zhang }
251076ba34aSJunchao Zhang 
2520e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
2530e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
2540e3ece09SJunchao Zhang template <class ExecutionSpace>
2550e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
256d71ae5a4SJacob Faibussowitsch {
2570e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
258076ba34aSJunchao Zhang 
259076ba34aSJunchao Zhang   PetscFunctionBegin;
2600e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
261076ba34aSJunchao Zhang 
2620e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
263076ba34aSJunchao Zhang 
2640e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
265076ba34aSJunchao Zhang 
2660e3ece09SJunchao Zhang   if (vector_length < 1) {
2670e3ece09SJunchao Zhang     vector_length = 1;
2680e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
269076ba34aSJunchao Zhang   }
270076ba34aSJunchao Zhang 
2710e3ece09SJunchao Zhang   // Determine rows per thread
2720e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2730e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
2740e3ece09SJunchao Zhang     else {
2750e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2760e3ece09SJunchao Zhang         rows_per_thread = 256;
2770e3ece09SJunchao Zhang       } else rows_per_thread = 64;
278076ba34aSJunchao Zhang     }
279076ba34aSJunchao Zhang   }
280076ba34aSJunchao Zhang 
2810e3ece09SJunchao Zhang   if (team_size < 1) {
2820e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
2830e3ece09SJunchao Zhang       team_size = 256 / vector_length;
284076ba34aSJunchao Zhang     } else {
2850e3ece09SJunchao Zhang       team_size = 1;
2860e3ece09SJunchao Zhang     }
287076ba34aSJunchao Zhang   }
288076ba34aSJunchao Zhang 
2890e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
290076ba34aSJunchao Zhang 
2910e3ece09SJunchao Zhang   if (rows_per_team < 0) {
2920e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
2930e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
2940e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
2950e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
2960e3ece09SJunchao Zhang   }
2973ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
298076ba34aSJunchao Zhang }
299076ba34aSJunchao Zhang 
3000e3ece09SJunchao Zhang /*
3010e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
302076ba34aSJunchao Zhang 
303076ba34aSJunchao Zhang   Input Parameters:
3040e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
3050e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
3060e3ece09SJunchao Zhang .  m           - size of indices[], the second set
3070e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
308076ba34aSJunchao Zhang 
309076ba34aSJunchao Zhang   Output Parameters:
3100e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
3110e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
3120e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
3130e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
314076ba34aSJunchao Zhang 
3150e3ece09SJunchao Zhang    Example, say
3160e3ece09SJunchao Zhang     n1         = 5
3170e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
3180e3ece09SJunchao Zhang     m          = 4
3190e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
32011a5261eSBarry Smith 
3210e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
3220e3ece09SJunchao Zhang     n2         = 7
3230e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
3240e3ece09SJunchao Zhang 
3250e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
3260e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
3270e3ece09SJunchao Zhang 
3280e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
3290e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
330076ba34aSJunchao Zhang */
3310e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
332d71ae5a4SJacob Faibussowitsch {
3330e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
3340e3ece09SJunchao Zhang   PetscHashIter iter;
3350e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
3360e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
337076ba34aSJunchao Zhang 
338076ba34aSJunchao Zhang   PetscFunctionBegin;
3390e3ece09SJunchao Zhang   tot = 0;
3400e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
3410e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
3420e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
3430e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
344076ba34aSJunchao Zhang   }
345076ba34aSJunchao Zhang 
3460e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
3470e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
3480e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
349076ba34aSJunchao Zhang   }
350076ba34aSJunchao Zhang 
3510e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
3520e3ece09SJunchao Zhang   n2 = tot;
3530e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
3540e3ece09SJunchao Zhang   tot = 0;
3550e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
3560e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
3570e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
3580e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
3590e3ece09SJunchao Zhang     garray2[tot++] = key;
360076ba34aSJunchao Zhang   }
361076ba34aSJunchao Zhang 
3620e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
3630e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
3640e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
3650e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
366f0e6e2d1SJunchao Zhang 
3670e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
368f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3690e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3700e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3710e3ece09SJunchao Zhang     indices[i] = val;
3720e3ece09SJunchao Zhang   }
3730e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3740e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3750e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3760e3ece09SJunchao Zhang   *n2_      = n2;
3770e3ece09SJunchao Zhang   *garray2_ = garray2;
3780e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3790e3ece09SJunchao Zhang }
380f0e6e2d1SJunchao Zhang 
3810e3ece09SJunchao Zhang /*
3820e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3830e3ece09SJunchao Zhang 
3840e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
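  Think of each row of E as a leaf; the given ownerSF then specifies the roots for the leaves. Roots may connect to multiple leaves.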
3860e3ece09SJunchao Zhang   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
3870e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3880e3ece09SJunchao Zhang 
3890e3ece09SJunchao Zhang   Input Parameters:
3900e3ece09SJunchao Zhang +  comm       - MPI communicator of E
3910e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
3920e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
3930e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
3940e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
3950e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
3960e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
3970e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
3980e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
3990e3ece09SJunchao Zhang 
4000e3ece09SJunchao Zhang   Output Parameters:
4010e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
4020e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
4030e3ece09SJunchao Zhang 
4040e3ece09SJunchao Zhang   Notes:
4050e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
4060e3ece09SJunchao Zhang 
4070e3ece09SJunchao Zhang  */
4080e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
4090e3ece09SJunchao Zhang {
4100e3ece09SJunchao Zhang   PetscFunctionBegin;
4110e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
4120e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
4130e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
4140e3ece09SJunchao Zhang 
4150e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
4160e3ece09SJunchao Zhang 
4170e3ece09SJunchao Zhang     // Do the analysis on host
4180e3ece09SJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
4190e3ece09SJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
4200e3ece09SJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
4210e3ece09SJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
4220e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
4230e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
4240e3ece09SJunchao Zhang 
4250e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
4267b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
4270e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
4280e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
4290e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
4300e3ece09SJunchao Zhang       PetscInt        count, step;
4310e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
4320e3ece09SJunchao Zhang       first = Bj + Bi[i];
4330e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
434f0e6e2d1SJunchao Zhang       count = last - first;
435f0e6e2d1SJunchao Zhang       while (count > 0) {
436f0e6e2d1SJunchao Zhang         it   = first;
437f0e6e2d1SJunchao Zhang         step = count / 2;
438f0e6e2d1SJunchao Zhang         it += step;
4390e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
440f0e6e2d1SJunchao Zhang           first = ++it;
441f0e6e2d1SJunchao Zhang           count -= step + 1;
442f0e6e2d1SJunchao Zhang         } else count = step;
443f0e6e2d1SJunchao Zhang       }
4440e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
4450e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
446f0e6e2d1SJunchao Zhang     }
447f0e6e2d1SJunchao Zhang 
4480e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
4490e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
4500e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
4510e3ece09SJunchao Zhang     PetscInt           niranks, nranks;
4520e3ece09SJunchao Zhang     MPI_Request       *reqs;
4530e3ece09SJunchao Zhang     PetscMPIInt        tag;
4540e3ece09SJunchao Zhang     PetscSF            reduceSF;
4550e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
456f0e6e2d1SJunchao Zhang 
4570e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
4580e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
4590e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
460f0e6e2d1SJunchao Zhang 
4610e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
4620e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
4630e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
4640e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
465f0e6e2d1SJunchao Zhang 
4660e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4670e3ece09SJunchao Zhang 
4680e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4690e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4700e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4710e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4720e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4730e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4740e3ece09SJunchao Zhang 
4750e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4760e3ece09SJunchao Zhang     rdisp[0] = 0;
4770e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4780e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4790e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4800e3ece09SJunchao Zhang     }
4810e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4820e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4830e3ece09SJunchao Zhang 
4840e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4850e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4860e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4870e3ece09SJunchao Zhang 
4880e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
4890e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
4900e3ece09SJunchao Zhang     PetscSFNode *iremote;
4910e3ece09SJunchao Zhang 
4920e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
4930e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
4940e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
4950e3ece09SJunchao Zhang 
4960e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
4970e3ece09SJunchao Zhang       PetscInt count = 0;
4980e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
4990e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
5000e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
5010e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
5020e3ece09SJunchao Zhang       }
5030e3ece09SJunchao Zhang       nleaves += count;
5040e3ece09SJunchao Zhang     }
5050e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
5060e3ece09SJunchao Zhang 
5070e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
5080e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
5090e3ece09SJunchao Zhang 
5100e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
5110e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
5120e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
5130e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
5140e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
5150e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
5160e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
5170e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
5180e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
5190e3ece09SJunchao Zhang         if (j < nzLeft) {
5200e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
5210e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
5220e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
5230e3ece09SJunchao Zhang         } else {
5240e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
5250e3ece09SJunchao Zhang         }
5260e3ece09SJunchao Zhang       }
5270e3ece09SJunchao Zhang     }
5280e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
5290e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
5300e3ece09SJunchao Zhang 
5310e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
5320e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
5330e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
5340e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
5350e3ece09SJunchao Zhang 
5360e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
5370e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
5380e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
5390e3ece09SJunchao Zhang 
5400e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
5417b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
5420e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
5430e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
5440e3ece09SJunchao Zhang     PetscInt                iter;
5450e3ece09SJunchao Zhang 
5460e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
5470e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
5480e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
5490e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
5500e3ece09SJunchao Zhang     iter = 0;
5510e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
5520e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
5530e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
5540e3ece09SJunchao Zhang 
5550e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5560e3ece09SJunchao Zhang 
5570e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
5580e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
5590e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
5600e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
5610e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
5620e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
5630e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5640e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
5650e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5660e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5670e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5680e3ece09SJunchao Zhang         jbuf2 += len;
5690e3ece09SJunchao Zhang         pbuf2 += len;
5700e3ece09SJunchao Zhang         nz += len;
5710e3ece09SJunchao Zhang       }
5720e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5730e3ece09SJunchao Zhang 
5740e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5750e3ece09SJunchao Zhang       PetscInt cur = 0;
5760e3ece09SJunchao Zhang       while (cur < nz) {
5770e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5780e3ece09SJunchao Zhang         PetscInt dups      = 1;
5790e3ece09SJunchao Zhang 
5800e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5810e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5820e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5830e3ece09SJunchao Zhang           FdnzDups += dups;
5840e3ece09SJunchao Zhang         } else {
5850e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5860e3ece09SJunchao Zhang           FonzDups += dups;
5870e3ece09SJunchao Zhang         }
5880e3ece09SJunchao Zhang         cur += dups;
5890e3ece09SJunchao Zhang       }
5900e3ece09SJunchao Zhang 
5910e3ece09SJunchao Zhang       FnzDups += nz;
5920e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5930e3ece09SJunchao Zhang     }
5940e3ece09SJunchao Zhang 
5950e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
5960e3ece09SJunchao Zhang     Foi = Foi_h.data();
5970e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
5980e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
5990e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
6000e3ece09SJunchao Zhang     }
6010e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
6020e3ece09SJunchao Zhang     Fonz = Foi[Fm];
6030e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
6040e3ece09SJunchao Zhang 
6050e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
6067b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
6077b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
6087b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
6090e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
6100e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
6110e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
6120e3ece09SJunchao Zhang 
6130e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
6140e3ece09SJunchao Zhang     Fdjmap[0] = 0;
6150e3ece09SJunchao Zhang     Fojmap[0] = 0;
6160e3ece09SJunchao Zhang     FnzDups   = 0;
6170e3ece09SJunchao Zhang     Fdnz      = 0;
6180e3ece09SJunchao Zhang     Fonz      = 0;
6190e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
6200e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
6210e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
6220e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
6230e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
6240e3ece09SJunchao Zhang 
6250e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
6260e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
6270e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
6280e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
6290e3ece09SJunchao Zhang       }
6300e3ece09SJunchao Zhang 
6310e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
6320e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
6330e3ece09SJunchao Zhang       PetscInt cur = 0;
6340e3ece09SJunchao Zhang       while (cur < nz) {
6350e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
6360e3ece09SJunchao Zhang         PetscInt dups      = 1;
6370e3ece09SJunchao Zhang 
6380e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
6390e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
6400e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
6410e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
6420e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
6430e3ece09SJunchao Zhang           FdnzDups += dups;
6440e3ece09SJunchao Zhang           Fdnz++;
6450e3ece09SJunchao Zhang         } else {
6460e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
6470e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
6480e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
6490e3ece09SJunchao Zhang           FonzDups += dups;
6500e3ece09SJunchao Zhang           Fonz++;
6510e3ece09SJunchao Zhang         }
6520e3ece09SJunchao Zhang         cur += dups;
6530e3ece09SJunchao Zhang         FnzDups += dups;
6540e3ece09SJunchao Zhang       }
6550e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6560e3ece09SJunchao Zhang     }
6570e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
6580e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
6590e3ece09SJunchao Zhang 
6600e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
6610e3ece09SJunchao Zhang     PetscInt n2, *garray2;
6620e3ece09SJunchao Zhang 
6630e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
6640e3ece09SJunchao Zhang     mm->sf       = reduceSF;
6657b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
6667b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
667aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6680e3ece09SJunchao Zhang     mm->n        = n2;
6690e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6700e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6710e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6720e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6730e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6740e3ece09SJunchao Zhang 
6750e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
6767b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6770e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6780e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
6797b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6800e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6810e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6820e3ece09SJunchao Zhang 
6830e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6840e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6850e3ece09SJunchao Zhang 
6860e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6870e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6880e3ece09SJunchao Zhang 
6890e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
6900e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
6910e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
6920e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
6930e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
6940e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
6950e3ece09SJunchao Zhang 
6960e3ece09SJunchao Zhang   // Handy aliases
6970e3ece09SJunchao Zhang   auto       &Aa           = A.values;
6980e3ece09SJunchao Zhang   auto       &Ba           = B.values;
6990e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
7000e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
7010e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
7020e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
7030e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
7040e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
7050e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
7060e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
7070e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
7080e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
7090e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
7100e3ece09SJunchao Zhang 
7110e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
7120e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
7130e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
7140e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
7150e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
7160e3ece09SJunchao Zhang         if (i < Em) {
7170e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
7180e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
7190e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
7200e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
7210e3ece09SJunchao Zhang 
7220e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
7230e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
7240e3ece09SJunchao Zhang             if (j < nzleft) { // B left
7250e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
7260e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
7270e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
7280e3ece09SJunchao Zhang             } else { // B right
7290e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
730f0e6e2d1SJunchao Zhang             }
731f0e6e2d1SJunchao Zhang           });
732f0e6e2d1SJunchao Zhang         }
733f0e6e2d1SJunchao Zhang       });
7340e3ece09SJunchao Zhang     }));
7350e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
736f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
737f0e6e2d1SJunchao Zhang }
7380e3ece09SJunchao Zhang 
739aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
7400e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
7410e3ece09SJunchao Zhang {
7420e3ece09SJunchao Zhang   PetscFunctionBegin;
7430e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
7440e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
7450e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
7460e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
7470e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
7480e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
7490e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
7500e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
7510e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
7520e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
7530e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
7540e3ece09SJunchao Zhang 
7550e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
7560e3ece09SJunchao Zhang 
7570e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
7580e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
7590e3ece09SJunchao Zhang     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
7600e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7610e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
7620e3ece09SJunchao Zhang       Fda(i) = sum;
7630e3ece09SJunchao Zhang     }));
7640e3ece09SJunchao Zhang 
7650e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
7660e3ece09SJunchao Zhang     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
7670e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7680e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7690e3ece09SJunchao Zhang       Foa(i) = sum;
7700e3ece09SJunchao Zhang     }));
7710e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7720e3ece09SJunchao Zhang }
7730e3ece09SJunchao Zhang 
7740e3ece09SJunchao Zhang /*
7750e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7760e3ece09SJunchao Zhang 
7770e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
7780e3ece09SJunchao Zhang   device and involves various index mapping.
7790e3ece09SJunchao Zhang 
7800e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
7810e3ece09SJunchao Zhang   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
7820e3ece09SJunchao Zhang   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7830e3ece09SJunchao Zhang   F has the same column layout as E.
7840e3ece09SJunchao Zhang 
7850e3ece09SJunchao Zhang   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
786aaa8cc7dSPierre Jolivet   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
7870e3ece09SJunchao Zhang   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
7880e3ece09SJunchao Zhang   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
7890e3ece09SJunchao Zhang   column indices in Fo and update Fo with local indices.
7900e3ece09SJunchao Zhang 
7910e3ece09SJunchao Zhang    Input Parameters:
7920e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
7939c89aa79SPierre Jolivet .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
7940e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
7950e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
7960e3ece09SJunchao Zhang 
7970e3ece09SJunchao Zhang     Output Parameters:
7980e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
7990e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
8000e3ece09SJunchao Zhang 
8010e3ece09SJunchao Zhang     Notes:
8020e3ece09SJunchao Zhang     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
8030e3ece09SJunchao Zhang     The routine is provided in split-phase form, MatMPIAIJKokkosBcastBegin/End(), to allow overlapping computation with communication.
8040e3ece09SJunchao Zhang */
8050e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
8060e3ece09SJunchao Zhang {
8070e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
8080e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
8090e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8100e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
8110e3ece09SJunchao Zhang   MPI_Comm          comm;
8120e3ece09SJunchao Zhang 
8130e3ece09SJunchao Zhang   PetscFunctionBegin;
8140e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
8150e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
8160e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
8170e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
8180e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
8190e3ece09SJunchao Zhang     PetscInt        cstart, cend;
8200e3ece09SJunchao Zhang     PetscSF         bcastSF;
8210e3ece09SJunchao Zhang 
8220e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
8230e3ece09SJunchao Zhang 
8240e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
8257b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
8260e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
8270e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
8280e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
8290e3ece09SJunchao Zhang       PetscInt        count, step;
8300e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
8310e3ece09SJunchao Zhang       first = Bj + Bi[i];
8320e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
8330e3ece09SJunchao Zhang       count = last - first;
8340e3ece09SJunchao Zhang       while (count > 0) {
8350e3ece09SJunchao Zhang         it   = first;
8360e3ece09SJunchao Zhang         step = count / 2;
8370e3ece09SJunchao Zhang         it += step;
8380e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
8390e3ece09SJunchao Zhang           first = ++it;
8400e3ece09SJunchao Zhang           count -= step + 1;
8410e3ece09SJunchao Zhang         } else count = step;
8420e3ece09SJunchao Zhang       }
8430e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
8440e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
8450e3ece09SJunchao Zhang     }
8460e3ece09SJunchao Zhang 
8470e3ece09SJunchao Zhang     // Compute row pointer Fi of F
8480e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
8490e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
8500e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
8510e3ece09SJunchao Zhang     Fi[0] = 0;
8520e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
8530e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
8540e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
8550e3ece09SJunchao Zhang     Fnz = Fi[Fm];
8560e3ece09SJunchao Zhang 
8570e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
8580e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
8590e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
8600e3ece09SJunchao Zhang     PetscInt           niranks, nranks, *sdisp, *rdisp;
8610e3ece09SJunchao Zhang     MPI_Request       *reqs;
8620e3ece09SJunchao Zhang     PetscMPIInt        tag;
8630e3ece09SJunchao Zhang 
8640e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8650e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8660e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8670e3ece09SJunchao Zhang 
8680e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8690e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8700e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8710e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8720e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8730e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8740e3ece09SJunchao Zhang       }
8750e3ece09SJunchao Zhang     }
8760e3ece09SJunchao Zhang 
8770e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8780e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8790e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8800e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8810e3ece09SJunchao Zhang 
8820e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8830e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8840e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8850e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8860e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8870e3ece09SJunchao Zhang       PetscInt k = 0;
8880e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
8890e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
8900e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
8910e3ece09SJunchao Zhang         k++;
8920e3ece09SJunchao Zhang       }
8930e3ece09SJunchao Zhang     }
8940e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
8950e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
8960e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
8970e3ece09SJunchao Zhang 
8980e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
8997b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
9000e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
9010e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
9027b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
9030e3ece09SJunchao Zhang 
9040e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
9050e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
9060e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
9070e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
9080e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
9090e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
9100e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
9110e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
9120e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
9130e3ece09SJunchao Zhang         if (j < nzLeft) {
9140e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
9150e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
9160e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
9170e3ece09SJunchao Zhang         } else {
9180e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
9190e3ece09SJunchao Zhang         }
9200e3ece09SJunchao Zhang       }
9210e3ece09SJunchao Zhang     }
9220e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
9230e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
9240e3ece09SJunchao Zhang 
9250e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
9267b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
9277b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
9280e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
9290e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
9300e3ece09SJunchao Zhang 
9310e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
9320e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9330e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
9340e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
9350e3ece09SJunchao Zhang       first       = Fj + Fi[i];
9360e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
9370e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
9380e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
9390e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
9400e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
9410e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
9420e3ece09SJunchao Zhang     }
9430e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9440e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
9450e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
9460e3ece09SJunchao Zhang     }
9470e3ece09SJunchao Zhang 
9480e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
9490e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
9507b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
9510e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
9520e3ece09SJunchao Zhang 
9530e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9540e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
9550e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
9560e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
9570e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
9580e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
9590e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
9600e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
9610e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
9620e3ece09SJunchao Zhang         } else { // right, in global
9630e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
9640e3ece09SJunchao Zhang         }
9650e3ece09SJunchao Zhang       }
9660e3ece09SJunchao Zhang     }
9670e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9680e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9690e3ece09SJunchao Zhang 
9700e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9710e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9720e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9730e3ece09SJunchao Zhang 
9740e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9750e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
9767b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9770e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9780e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9790e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9800e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9810e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9820e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
9837b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
9847b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9850e3ece09SJunchao Zhang     mm->garray    = garray2;
9860e3ece09SJunchao Zhang     mm->n         = n2;
9870e3ece09SJunchao Zhang 
9880e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
9897b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
9900e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
9910e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
9920e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
9930e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
9940e3ece09SJunchao Zhang 
9950e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
9960e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
9970e3ece09SJunchao Zhang 
9980e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
9990e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
10000e3ece09SJunchao Zhang 
10010e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10020e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
10030e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
10040e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
10050e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
10060e3ece09SJunchao Zhang 
10070e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10080e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
10090e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
10100e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
10110e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
10120e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
10130e3ece09SJunchao Zhang 
10140e3ece09SJunchao Zhang   // Sync E's value to device
10150e3ece09SJunchao Zhang   akok->a_dual.sync_device();
10160e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
10170e3ece09SJunchao Zhang 
10180e3ece09SJunchao Zhang   // Handy aliases
10190e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
10200e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
10210e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
10220e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
10230e3ece09SJunchao Zhang 
10240e3ece09SJunchao Zhang   // Fetch the plans
10250e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
10260e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
10270e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
10280e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
10290e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
10300e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
10310e3ece09SJunchao Zhang 
10320e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
10330e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
10340e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
10350e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
10360e3ece09SJunchao Zhang 
10370e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
10380e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
10390e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10400e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10410e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
10420e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
10430e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
10440e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
10450e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
10460e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
10470e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
10480e3ece09SJunchao Zhang 
10490e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10500e3ece09SJunchao Zhang             if (j < nzleft) { // B left
10510e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
10520e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
10530e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
10540e3ece09SJunchao Zhang             } else { // B right
10550e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
10560e3ece09SJunchao Zhang             }
10570e3ece09SJunchao Zhang           });
10580e3ece09SJunchao Zhang         }
10590e3ece09SJunchao Zhang       });
10600e3ece09SJunchao Zhang     }));
10610e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
10620e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10630e3ece09SJunchao Zhang }
10640e3ece09SJunchao Zhang 
10650e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10660e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10670e3ece09SJunchao Zhang {
10680e3ece09SJunchao Zhang   PetscFunctionBegin;
10690e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10700e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10710e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10720e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10730e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10740e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10750e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10760e3ece09SJunchao Zhang 
10770e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10780e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10790e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10800e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10810e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10820e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10830e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10840e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10850e3ece09SJunchao Zhang 
10860e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10870e3ece09SJunchao Zhang 
10880e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
10890e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
10900e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10910e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10920e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
10930e3ece09SJunchao Zhang         if (i < Fm) {
10940e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
10950e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
10960e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
10970e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
10980e3ece09SJunchao Zhang 
10990e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
11000e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
11010e3ece09SJunchao Zhang             if (j < nzLeft) { // left
11020e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
11030e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
11040e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
11050e3ece09SJunchao Zhang             } else { // right
11060e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
11070e3ece09SJunchao Zhang             }
11080e3ece09SJunchao Zhang           });
11090e3ece09SJunchao Zhang         }
11100e3ece09SJunchao Zhang       });
11110e3ece09SJunchao Zhang     }));
11120e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11130e3ece09SJunchao Zhang }
11140e3ece09SJunchao Zhang 
11150e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11160e3ece09SJunchao Zhang {
11170e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11180e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11190e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
11200e3ece09SJunchao Zhang   PetscInt        cstart, cend;
11210e3ece09SJunchao Zhang   MPI_Comm        comm;
11220e3ece09SJunchao Zhang 
11230e3ece09SJunchao Zhang   PetscFunctionBegin;
11240e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11250e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11260e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11270e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
11280e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
11290e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11300e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11310e3ece09SJunchao Zhang 
11320e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11330e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11340e3ece09SJunchao Zhang 
11350e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11360e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11370e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11380e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1139f0e6e2d1SJunchao Zhang   #endif
11400e3ece09SJunchao Zhang #endif
11410e3ece09SJunchao Zhang 
11420e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
11430e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
11440e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
11450e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
11460e3ece09SJunchao Zhang 
11470e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11480e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
11490e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
11500e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
11510e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
11520e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11530e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11540e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11550e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
11560e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
11570e3ece09SJunchao Zhang #endif
11580e3ece09SJunchao Zhang 
11590e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11607b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11610e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
11620e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11630e3ece09SJunchao Zhang 
11640e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11650e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11660e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11680e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11690e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11700e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11710e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11720e3ece09SJunchao Zhang #endif
11730e3ece09SJunchao Zhang 
11740e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11750e3ece09SJunchao Zhang 
11760e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
11777b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11780e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
11790e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
11800e3ece09SJunchao Zhang     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11810e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11820e3ece09SJunchao Zhang 
11830e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11840e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11850e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11860e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11870e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11880e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11890e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11900e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11910e3ece09SJunchao Zhang }
11920e3ece09SJunchao Zhang 
11930e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11940e3ece09SJunchao Zhang {
11950e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11960e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11970e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
11980e3ece09SJunchao Zhang   MPI_Comm        comm;
11990e3ece09SJunchao Zhang 
12000e3ece09SJunchao Zhang   PetscFunctionBegin;
12010e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
12020e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
12030e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
12040e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12050e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12060e3ece09SJunchao Zhang 
12070e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
12080e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
12090e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
12100e3ece09SJunchao Zhang 
12110e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
12120e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12130e3ece09SJunchao Zhang 
12140e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
12150e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
12160e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
12170e3ece09SJunchao Zhang 
12180e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12190e3ece09SJunchao Zhang 
12200e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
12210e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
12220e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
12230e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12240e3ece09SJunchao Zhang }
1225f0e6e2d1SJunchao Zhang 
1226076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1227076ba34aSJunchao Zhang 
1228076ba34aSJunchao Zhang   Input Parameters:
1229076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1230076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1231076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1232076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1233076ba34aSJunchao Zhang */
1234d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1235d71ae5a4SJacob Faibussowitsch {
12360e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12370e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12380e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1239076ba34aSJunchao Zhang 
1240076ba34aSJunchao Zhang   PetscFunctionBegin;
12410e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12420e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12430e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12440e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12450e3ece09SJunchao Zhang 
12460e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
12470e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
12480e3ece09SJunchao Zhang 
12490e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
12500e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
12510e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
12520e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
12530e3ece09SJunchao Zhang   #endif
1254f0e6e2d1SJunchao Zhang #endif
1255f0e6e2d1SJunchao Zhang 
12560e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
12570e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
12580e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
12590e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1260076ba34aSJunchao Zhang 
12610e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12627b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
12630e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1264076ba34aSJunchao Zhang 
12650e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12660e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12680e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12690e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12700e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12710e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12720e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12730e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12740e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12750e3ece09SJunchao Zhang #endif
1276076ba34aSJunchao Zhang 
12770e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1278076ba34aSJunchao Zhang 
12790e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12800e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12810e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12820e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12830e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12840e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12850e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12860e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12870e3ece09SJunchao Zhang #endif
1288076ba34aSJunchao Zhang 
12890e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
12907b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
12910e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
12920e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
12930e3ece09SJunchao Zhang     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
12940e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
12950e3ece09SJunchao Zhang 
12960e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12970e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
12980e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
12990e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
13000e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
13010e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13020e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1304076ba34aSJunchao Zhang }
1305076ba34aSJunchao Zhang 
13060e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1307d71ae5a4SJacob Faibussowitsch {
13080e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
13090e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
13100e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1311076ba34aSJunchao Zhang 
1312076ba34aSJunchao Zhang   PetscFunctionBegin;
13130e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
13140e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
13150e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
13160e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1317076ba34aSJunchao Zhang 
13180e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
13190e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1320076ba34aSJunchao Zhang 
13210e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
13220e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
13230e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1324076ba34aSJunchao Zhang 
13250e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1326076ba34aSJunchao Zhang 
13270e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
13280e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
13290e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
13300e3ece09SJunchao Zhang 
13310e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13320e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1335076ba34aSJunchao Zhang }
1336076ba34aSJunchao Zhang 
1337d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1338d71ae5a4SJacob Faibussowitsch {
13390e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
13400e3ece09SJunchao Zhang   Mat_Product                 *product;
13410e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1342076ba34aSJunchao Zhang   MatProductType               ptype;
13430e3ece09SJunchao Zhang   Mat                          A, B;
1344076ba34aSJunchao Zhang 
1345076ba34aSJunchao Zhang   PetscFunctionBegin;
13460e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
13470e3ece09SJunchao Zhang   product = C->product;
13480e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1349076ba34aSJunchao Zhang   ptype   = product->type;
1350076ba34aSJunchao Zhang   A       = product->A;
1351076ba34aSJunchao Zhang   B       = product->B;
1352076ba34aSJunchao Zhang 
13530e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
13540e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
13550e3ece09SJunchao Zhang   // we still do numeric.
13560e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
13570e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13583ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1359076ba34aSJunchao Zhang   }
1360076ba34aSJunchao Zhang 
1361076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13620e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1363076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13640e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
13650e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13660e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13670e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1368076ba34aSJunchao Zhang   }
13690e3ece09SJunchao Zhang 
13700e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13710e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1373076ba34aSJunchao Zhang }
1374076ba34aSJunchao Zhang 
1375d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1376d71ae5a4SJacob Faibussowitsch {
1377076ba34aSJunchao Zhang   Mat                          A, B;
13780e3ece09SJunchao Zhang   Mat_Product                 *product;
1379076ba34aSJunchao Zhang   MatProductType               ptype;
13800e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1381076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13820e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13830e3ece09SJunchao Zhang   Mat                          Cd, Co;
13840e3ece09SJunchao Zhang   MPI_Comm                     comm;
1385076ba34aSJunchao Zhang 
1386076ba34aSJunchao Zhang   PetscFunctionBegin;
13870e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1388076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13890e3ece09SJunchao Zhang   product = C->product;
13900e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1391076ba34aSJunchao Zhang   ptype = product->type;
1392076ba34aSJunchao Zhang   A     = product->A;
1393076ba34aSJunchao Zhang   B     = product->B;
1394076ba34aSJunchao Zhang 
1395076ba34aSJunchao Zhang   switch (ptype) {
13969371c9d4SSatish Balay   case MATPRODUCT_AB:
13979371c9d4SSatish Balay     m = A->rmap->n;
13989371c9d4SSatish Balay     n = B->cmap->n;
13999371c9d4SSatish Balay     M = A->rmap->N;
14009371c9d4SSatish Balay     N = B->cmap->N;
14019371c9d4SSatish Balay     break;
14029371c9d4SSatish Balay   case MATPRODUCT_AtB:
14039371c9d4SSatish Balay     m = A->cmap->n;
14049371c9d4SSatish Balay     n = B->cmap->n;
14059371c9d4SSatish Balay     M = A->cmap->N;
14069371c9d4SSatish Balay     N = B->cmap->N;
14079371c9d4SSatish Balay     break;
14089371c9d4SSatish Balay   case MATPRODUCT_PtAP:
14099371c9d4SSatish Balay     m = B->cmap->n;
14109371c9d4SSatish Balay     n = B->cmap->n;
14119371c9d4SSatish Balay     M = B->cmap->N;
14129371c9d4SSatish Balay     N = B->cmap->N;
14139371c9d4SSatish Balay     break; /* BtAB */
1414d71ae5a4SJacob Faibussowitsch   default:
14150e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1416076ba34aSJunchao Zhang   }
1417076ba34aSJunchao Zhang 
14189566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
14199566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
14209566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
14219566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1422076ba34aSJunchao Zhang 
14230e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
14240e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1425076ba34aSJunchao Zhang 
1426076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
14270e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14280e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
14290e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1430076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
14310e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14320e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
14330e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
14340e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
14350e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
14360e3ece09SJunchao Zhang 
14370e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14380e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
14390e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
14400e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
14410e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
14420e3ece09SJunchao Zhang 
14430e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
14440e3ece09SJunchao Zhang     n = B->cmap->n;
14450e3ece09SJunchao Zhang     M = A->rmap->N;
14460e3ece09SJunchao Zhang     N = B->cmap->N;
14470e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
14480e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
14490e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
14500e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
14510e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
14520e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
14530e3ece09SJunchao Zhang 
14540e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14550e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
14560e3ece09SJunchao Zhang 
14570e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
14580e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1459076ba34aSJunchao Zhang   }
14600e3ece09SJunchao Zhang 
14610e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
14620e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
14630e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
14640e3ece09SJunchao Zhang 
14650e3ece09SJunchao Zhang   C->product->data       = pdata;
1466076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1467076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14683ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1469076ba34aSJunchao Zhang }
1470076ba34aSJunchao Zhang 
1471d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1472d71ae5a4SJacob Faibussowitsch {
1473076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1474076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1475076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1476076ba34aSJunchao Zhang 
1477076ba34aSJunchao Zhang   PetscFunctionBegin;
1478076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
147948a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1480076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1481076ba34aSJunchao Zhang     switch (product->type) {
1482076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1483076ba34aSJunchao Zhang       if (product->api_user) {
1484d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14859566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1486d0609cedSBarry Smith         PetscOptionsEnd();
1487076ba34aSJunchao Zhang       } else {
1488d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14899566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1490d0609cedSBarry Smith         PetscOptionsEnd();
1491076ba34aSJunchao Zhang       }
1492076ba34aSJunchao Zhang       break;
1493076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1494076ba34aSJunchao Zhang       if (product->api_user) {
1495d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
14969566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1497d0609cedSBarry Smith         PetscOptionsEnd();
1498076ba34aSJunchao Zhang       } else {
1499d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
15009566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1501d0609cedSBarry Smith         PetscOptionsEnd();
1502076ba34aSJunchao Zhang       }
1503076ba34aSJunchao Zhang       break;
1504076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1505076ba34aSJunchao Zhang       if (product->api_user) {
1506d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
15079566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1508d0609cedSBarry Smith         PetscOptionsEnd();
1509076ba34aSJunchao Zhang       } else {
1510d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
15119566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1512d0609cedSBarry Smith         PetscOptionsEnd();
1513076ba34aSJunchao Zhang       }
1514076ba34aSJunchao Zhang       break;
1515d71ae5a4SJacob Faibussowitsch     default:
1516d71ae5a4SJacob Faibussowitsch       break;
1517076ba34aSJunchao Zhang     }
1518076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1519076ba34aSJunchao Zhang   }
1520076ba34aSJunchao Zhang   if (match) {
1521076ba34aSJunchao Zhang     switch (product->type) {
1522076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1523076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1524d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1525d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1526d71ae5a4SJacob Faibussowitsch       break;
1527d71ae5a4SJacob Faibussowitsch     default:
1528d71ae5a4SJacob Faibussowitsch       break;
1529076ba34aSJunchao Zhang     }
1530076ba34aSJunchao Zhang   }
1531076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
153248a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
15333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1534076ba34aSJunchao Zhang }
1535076ba34aSJunchao Zhang 
15362c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device
15372c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos {
15382c4ab24aSJunchao Zhang   PetscCount           n;
15392c4ab24aSJunchao Zhang   PetscSF              sf;
15402c4ab24aSJunchao Zhang   PetscCount           Annz, Bnnz;
15412c4ab24aSJunchao Zhang   PetscCount           Annz2, Bnnz2;
15422c4ab24aSJunchao Zhang   PetscCountKokkosView Ajmap1, Aperm1;
15432c4ab24aSJunchao Zhang   PetscCountKokkosView Bjmap1, Bperm1;
15442c4ab24aSJunchao Zhang   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
15452c4ab24aSJunchao Zhang   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
15462c4ab24aSJunchao Zhang   PetscCountKokkosView Cperm1;
15472c4ab24aSJunchao Zhang   MatScalarKokkosView  sendbuf, recvbuf;
15482c4ab24aSJunchao Zhang 
15492c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
15502c4ab24aSJunchao Zhang     n(coo_h->n),
15512c4ab24aSJunchao Zhang     sf(coo_h->sf),
15522c4ab24aSJunchao Zhang     Annz(coo_h->Annz),
15532c4ab24aSJunchao Zhang     Bnnz(coo_h->Bnnz),
15542c4ab24aSJunchao Zhang     Annz2(coo_h->Annz2),
15552c4ab24aSJunchao Zhang     Bnnz2(coo_h->Bnnz2),
15562c4ab24aSJunchao Zhang     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
15572c4ab24aSJunchao Zhang     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
15582c4ab24aSJunchao Zhang     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
15592c4ab24aSJunchao Zhang     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
15602c4ab24aSJunchao Zhang     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
15612c4ab24aSJunchao Zhang     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
15622c4ab24aSJunchao Zhang     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
15632c4ab24aSJunchao Zhang     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
15642c4ab24aSJunchao Zhang     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
15652c4ab24aSJunchao Zhang     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
15662c4ab24aSJunchao Zhang     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
15672c4ab24aSJunchao Zhang     sendbuf(Kokkos::create_mirror_view(DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
15682c4ab24aSJunchao Zhang     recvbuf(Kokkos::create_mirror_view(DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
15692c4ab24aSJunchao Zhang   {
15702c4ab24aSJunchao Zhang     PetscCallVoid(PetscObjectReference((PetscObject)sf));
15712c4ab24aSJunchao Zhang   }
15722c4ab24aSJunchao Zhang 
15732c4ab24aSJunchao Zhang   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
15742c4ab24aSJunchao Zhang };
15752c4ab24aSJunchao Zhang 
15762c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
15772c4ab24aSJunchao Zhang {
15782c4ab24aSJunchao Zhang   PetscFunctionBegin;
15792c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
15802c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15812c4ab24aSJunchao Zhang }
15822c4ab24aSJunchao Zhang 
1583d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1584d71ae5a4SJacob Faibussowitsch {
15852c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
15862c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJ       *coo_h;
15872c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo_d;
158842550becSJunchao Zhang 
158942550becSJunchao Zhang   PetscFunctionBegin;
159030203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1591cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15929566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15939566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15949566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
15952c4ab24aSJunchao Zhang 
15962c4ab24aSJunchao Zhang   // Copy the COO struct to device
15972c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
15982c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
15992c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
16002c4ab24aSJunchao Zhang 
16012c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
16022c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
16032c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
16042c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
16052c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
16062c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
16073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
160842550becSJunchao Zhang }
160942550becSJunchao Zhang 
1610d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1611d71ae5a4SJacob Faibussowitsch {
1612394ed5ebSJunchao Zhang   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
161342550becSJunchao Zhang   Mat                        A = mpiaij->A, B = mpiaij->B;
161442550becSJunchao Zhang   MatScalarKokkosView        Aa, Ba;
1615394ed5ebSJunchao Zhang   MatScalarKokkosView        v1;
161642550becSJunchao Zhang   PetscMemType               memtype;
16172c4ab24aSJunchao Zhang   PetscContainer             container;
16182c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo;
161942550becSJunchao Zhang 
162042550becSJunchao Zhang   PetscFunctionBegin;
16212c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
16222c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
16232c4ab24aSJunchao Zhang 
16242c4ab24aSJunchao Zhang   const auto &n      = coo->n;
16252c4ab24aSJunchao Zhang   const auto &Annz   = coo->Annz;
16262c4ab24aSJunchao Zhang   const auto &Annz2  = coo->Annz2;
16272c4ab24aSJunchao Zhang   const auto &Bnnz   = coo->Bnnz;
16282c4ab24aSJunchao Zhang   const auto &Bnnz2  = coo->Bnnz2;
16292c4ab24aSJunchao Zhang   const auto &vsend  = coo->sendbuf;
16302c4ab24aSJunchao Zhang   const auto &v2     = coo->recvbuf;
16312c4ab24aSJunchao Zhang   const auto &Ajmap1 = coo->Ajmap1;
16322c4ab24aSJunchao Zhang   const auto &Ajmap2 = coo->Ajmap2;
16332c4ab24aSJunchao Zhang   const auto &Aimap2 = coo->Aimap2;
16342c4ab24aSJunchao Zhang   const auto &Bjmap1 = coo->Bjmap1;
16352c4ab24aSJunchao Zhang   const auto &Bjmap2 = coo->Bjmap2;
16362c4ab24aSJunchao Zhang   const auto &Bimap2 = coo->Bimap2;
16372c4ab24aSJunchao Zhang   const auto &Aperm1 = coo->Aperm1;
16382c4ab24aSJunchao Zhang   const auto &Aperm2 = coo->Aperm2;
16392c4ab24aSJunchao Zhang   const auto &Bperm1 = coo->Bperm1;
16402c4ab24aSJunchao Zhang   const auto &Bperm2 = coo->Bperm2;
16412c4ab24aSJunchao Zhang   const auto &Cperm1 = coo->Cperm1;
16422c4ab24aSJunchao Zhang 
16439566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
164442550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
16452c4ab24aSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
164642550becSJunchao Zhang   } else {
16472c4ab24aSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
164842550becSJunchao Zhang   }
164942550becSJunchao Zhang 
165042550becSJunchao Zhang   if (imode == INSERT_VALUES) {
16519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
16529566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1653394ed5ebSJunchao Zhang   } else {
16549566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
16559566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
165642550becSJunchao Zhang   }
165742550becSJunchao Zhang 
1658*08bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
165942550becSJunchao Zhang   /* Pack entries to be sent to remote */
16609371c9d4SSatish Balay   Kokkos::parallel_for(
16619371c9d4SSatish Balay     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
166242550becSJunchao Zhang 
166342550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
16642c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1665158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
16669371c9d4SSatish Balay   Kokkos::parallel_for(
16679371c9d4SSatish Balay     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1668158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1669158ec288SJunchao Zhang       if (i < Annz) {
1670158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1671ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1672158ec288SJunchao Zhang       } else {
1673158ec288SJunchao Zhang         i -= Annz;
1674158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1675ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1676158ec288SJunchao Zhang       }
1677158ec288SJunchao Zhang     });
16782c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
167942550becSJunchao Zhang 
1680158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16819371c9d4SSatish Balay   Kokkos::parallel_for(
16829371c9d4SSatish Balay     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1683158ec288SJunchao Zhang       if (i < Annz2) {
1684158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1685158ec288SJunchao Zhang       } else {
1686158ec288SJunchao Zhang         i -= Annz2;
1687158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1688158ec288SJunchao Zhang       }
1689158ec288SJunchao Zhang     });
1690*08bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
169142550becSJunchao Zhang 
1692394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
16939566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
16949566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1695394ed5ebSJunchao Zhang   } else {
16969566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
16979566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1698394ed5ebSJunchao Zhang   }
16993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170042550becSJunchao Zhang }
170142550becSJunchao Zhang 
17022c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1703d71ae5a4SJacob Faibussowitsch {
1704076ba34aSJunchao Zhang   PetscFunctionBegin;
17059566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
17069566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
17079566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
17089566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
17099566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
17103ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1711076ba34aSJunchao Zhang }
1712076ba34aSJunchao Zhang 
17132c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
17142c4ab24aSJunchao Zhang {
17152c4ab24aSJunchao Zhang   PetscFunctionBegin;
17162c4ab24aSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
17172c4ab24aSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
17182c4ab24aSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
17192c4ab24aSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
17202c4ab24aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
17212c4ab24aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
17222c4ab24aSJunchao Zhang 
17232c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
17242c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
17252c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
17262c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
17272c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
17282c4ab24aSJunchao Zhang }
17292c4ab24aSJunchao Zhang 
1730d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1731d71ae5a4SJacob Faibussowitsch {
17328c3ff71bSJunchao Zhang   Mat         B;
1733076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
17348c3ff71bSJunchao Zhang 
17358c3ff71bSJunchao Zhang   PetscFunctionBegin;
17368c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17379566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
17388c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17399566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
17408c3ff71bSJunchao Zhang   }
17418c3ff71bSJunchao Zhang   B = *newmat;
17428c3ff71bSJunchao Zhang 
17436f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17449566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
17459566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
17469566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
17478c3ff71bSJunchao Zhang 
1748076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
17499566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
17509566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
17519566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
17522c4ab24aSJunchao Zhang   PetscCall(MatSetOps_MPIAIJKokkos(B));
17533ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17548c3ff71bSJunchao Zhang }
17552c4ab24aSJunchao Zhang 
17563f3ba80aSJunchao Zhang /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
17588c3ff71bSJunchao Zhang 
   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
17603f3ba80aSJunchao Zhang 
17612ef1f0ffSBarry Smith    Options Database Key:
17622ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
17633f3ba80aSJunchao Zhang 
17643f3ba80aSJunchao Zhang   Level: beginner
17653f3ba80aSJunchao Zhang 
17661cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
17673f3ba80aSJunchao Zhang M*/
1768d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1769d71ae5a4SJacob Faibussowitsch {
17708c3ff71bSJunchao Zhang   PetscFunctionBegin;
17719566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
17729566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
17739566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
17743ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17758c3ff71bSJunchao Zhang }
17768c3ff71bSJunchao Zhang 
17778c3ff71bSJunchao Zhang /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
17818c3ff71bSJunchao Zhang 
17828c3ff71bSJunchao Zhang   Collective
17838c3ff71bSJunchao Zhang 
17848c3ff71bSJunchao Zhang   Input Parameters:
+ comm  - MPI communicator
. m     - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given).
           This value should be the same as the local size used in creating the
           y vector for the matrix-vector product y = Ax.
. n     - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given).
           This value should be the same as the local size used in creating the
           x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
. M     - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
. N     - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
179420f4b53cSBarry Smith . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
179520f4b53cSBarry Smith            (same value is used for all local rows)
179620f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the
179720f4b53cSBarry Smith            DIAGONAL portion of the local submatrix (possibly different for each row)
179820f4b53cSBarry Smith            or `NULL`, if `d_nz` is used to specify the nonzero structure.
           The size of this array is equal to the number of local rows, i.e. `m`.
180020f4b53cSBarry Smith            For matrices you plan to factor you must leave room for the diagonal entry and
180120f4b53cSBarry Smith            put in the entry even if it is zero.
180220f4b53cSBarry Smith . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
180320f4b53cSBarry Smith            submatrix (same value is used for all local rows).
180420f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the
180520f4b53cSBarry Smith            OFF-DIAGONAL portion of the local submatrix (possibly different for
180620f4b53cSBarry Smith            each row) or `NULL`, if `o_nz` is used to specify the nonzero
180720f4b53cSBarry Smith            structure. The size of this array is equal to the number
           of local rows, i.e. `m`.
18098c3ff71bSJunchao Zhang 
18108c3ff71bSJunchao Zhang   Output Parameter:
18118c3ff71bSJunchao Zhang . A - the matrix
18128c3ff71bSJunchao Zhang 
18132ef1f0ffSBarry Smith   Level: intermediate
18142ef1f0ffSBarry Smith 
18152ef1f0ffSBarry Smith   Notes:
181611a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
18178c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
181811a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
18198c3ff71bSJunchao Zhang 
  The AIJ format, also called compressed row storage, is fully compatible with standard Fortran
18218c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
18222ef1f0ffSBarry Smith   either one (as in Fortran) or zero.
18238c3ff71bSJunchao Zhang 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
18268c3ff71bSJunchao Zhang @*/
1827d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1828d71ae5a4SJacob Faibussowitsch {
18298c3ff71bSJunchao Zhang   PetscMPIInt size;
18308c3ff71bSJunchao Zhang 
18318c3ff71bSJunchao Zhang   PetscFunctionBegin;
18329566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
18339566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
18349566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
18358c3ff71bSJunchao Zhang   if (size > 1) {
18369566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
18379566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
18388c3ff71bSJunchao Zhang   } else {
18399566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
18409566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
18418c3ff71bSJunchao Zhang   }
18423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18438c3ff71bSJunchao Zhang }
1844