xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision d326c3f1c5d950e31249cdd49d2b04d6dbe295f4)
1*d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
52c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
68c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
80e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
911d22bbfSJunchao Zhang 
1066976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11d71ae5a4SJacob Faibussowitsch {
1230203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
138c3ff71bSJunchao Zhang 
148c3ff71bSJunchao Zhang   PetscFunctionBegin;
159566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1630203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1730203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1830203840SJunchao Zhang    */
1930203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2030203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2130203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2230203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2330203840SJunchao Zhang   }
243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
258c3ff71bSJunchao Zhang }
268c3ff71bSJunchao Zhang 
2766976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
28d71ae5a4SJacob Faibussowitsch {
298c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
308c3ff71bSJunchao Zhang 
318c3ff71bSJunchao Zhang   PetscFunctionBegin;
32913a874dSJunchao Zhang   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
33913a874dSJunchao Zhang   if (mat->hash_active) {
34913a874dSJunchao Zhang     mat->ops[0]      = mpiaij->cops;
35913a874dSJunchao Zhang     mat->hash_active = PETSC_FALSE;
36913a874dSJunchao Zhang   }
37913a874dSJunchao Zhang 
389566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
399566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
406a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
418c3ff71bSJunchao Zhang   if (d_nnz) {
426a29ce69SStefano Zampini     PetscInt i;
43ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
448c3ff71bSJunchao Zhang   }
458c3ff71bSJunchao Zhang   if (o_nnz) {
466a29ce69SStefano Zampini     PetscInt i;
47ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
488c3ff71bSJunchao Zhang   }
496a29ce69SStefano Zampini #endif
506a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
51eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
526a29ce69SStefano Zampini #else
539566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
546a29ce69SStefano Zampini #endif
559566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
569566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
579566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
586a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
599566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
606a29ce69SStefano Zampini 
616a29ce69SStefano Zampini   if (!mpiaij->A) {
629566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
639566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
646a29ce69SStefano Zampini   }
656a29ce69SStefano Zampini   if (!mpiaij->B) {
666a29ce69SStefano Zampini     PetscMPIInt size;
679566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
689566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
699566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
708c3ff71bSJunchao Zhang   }
719566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
729566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
749566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
758c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
778c3ff71bSJunchao Zhang }
788c3ff71bSJunchao Zhang 
7966976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
80d71ae5a4SJacob Faibussowitsch {
818c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
828c3ff71bSJunchao Zhang   PetscInt    nt;
838c3ff71bSJunchao Zhang 
848c3ff71bSJunchao Zhang   PetscFunctionBegin;
859566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8608401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
879566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
889566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
899566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
909566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
928c3ff71bSJunchao Zhang }
938c3ff71bSJunchao Zhang 
9466976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
95d71ae5a4SJacob Faibussowitsch {
968c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
978c3ff71bSJunchao Zhang   PetscInt    nt;
988c3ff71bSJunchao Zhang 
998c3ff71bSJunchao Zhang   PetscFunctionBegin;
1009566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
10108401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
1029566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1039566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
1049566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1059566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
1063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1078c3ff71bSJunchao Zhang }
1088c3ff71bSJunchao Zhang 
10966976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
110d71ae5a4SJacob Faibussowitsch {
1118c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1128c3ff71bSJunchao Zhang   PetscInt    nt;
1138c3ff71bSJunchao Zhang 
1148c3ff71bSJunchao Zhang   PetscFunctionBegin;
1159566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
11608401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1179566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1189566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1199566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1209566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1228c3ff71bSJunchao Zhang }
1238c3ff71bSJunchao Zhang 
124076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
125076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
126076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
127076ba34aSJunchao Zhang */
12866976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
129d71ae5a4SJacob Faibussowitsch {
130076ba34aSJunchao Zhang   Mat             Ad, Ao;
131076ba34aSJunchao Zhang   const PetscInt *cmap;
132076ba34aSJunchao Zhang 
133076ba34aSJunchao Zhang   PetscFunctionBegin;
1349566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1359566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
136076ba34aSJunchao Zhang   if (glob) {
137076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1389566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1399566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1409566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1419566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
142076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
143076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1449566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
145076ba34aSJunchao Zhang   }
1463ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
147076ba34aSJunchao Zhang }
148076ba34aSJunchao Zhang 
1490e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
150076ba34aSJunchao Zhang struct MatMatStruct {
1510e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1520e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1530e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1540e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1550e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1560e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1570e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1580e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1590e3ece09SJunchao Zhang 
1600e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1610e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1620e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1630e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1640e3ece09SJunchao Zhang 
165aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1660e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1670e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1680e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1690e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1700e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
171076ba34aSJunchao Zhang 
172d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
173d71ae5a4SJacob Faibussowitsch   {
1743ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1753ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1763ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
177076ba34aSJunchao Zhang   }
178076ba34aSJunchao Zhang };
179076ba34aSJunchao Zhang 
180076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1810e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1820e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1830e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
184076ba34aSJunchao Zhang };
185076ba34aSJunchao Zhang 
186076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1870e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1880e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1890e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1900e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
191076ba34aSJunchao Zhang };
192076ba34aSJunchao Zhang 
1939371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1943ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1953ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1963ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
1970e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
198076ba34aSJunchao Zhang 
199d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
200d71ae5a4SJacob Faibussowitsch   {
201076ba34aSJunchao Zhang     delete mmAB;
202076ba34aSJunchao Zhang     delete mmAtB;
2030e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
204076ba34aSJunchao Zhang   }
205076ba34aSJunchao Zhang };
206076ba34aSJunchao Zhang 
207d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
208d71ae5a4SJacob Faibussowitsch {
209076ba34aSJunchao Zhang   PetscFunctionBegin;
2109566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
2113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
212076ba34aSJunchao Zhang }
213076ba34aSJunchao Zhang 
214076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
215076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
216076ba34aSJunchao Zhang 
217076ba34aSJunchao Zhang   Input Parameters:
218076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
219076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
220076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
221076ba34aSJunchao Zhang 
2222fe279fdSBarry Smith   Output Parameter:
223076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
224076ba34aSJunchao Zhang */
2250e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
226d71ae5a4SJacob Faibussowitsch {
227076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
228076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
229076ba34aSJunchao Zhang 
230076ba34aSJunchao Zhang   PetscFunctionBegin;
2319566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2329566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2339566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2349566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
235076ba34aSJunchao Zhang 
236aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
23708401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
2380e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
23908401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
240076ba34aSJunchao Zhang   mpiaij->A      = A;
241076ba34aSJunchao Zhang   mpiaij->B      = B;
2420e3ece09SJunchao Zhang   mpiaij->garray = garray;
243076ba34aSJunchao Zhang 
244076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
245076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
246076ba34aSJunchao Zhang 
2479566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2489566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
249076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
250076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
251076ba34aSJunchao Zhang   */
2529566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2539566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2549566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
256076ba34aSJunchao Zhang }
257076ba34aSJunchao Zhang 
2580e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
2590e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
2600e3ece09SJunchao Zhang template <class ExecutionSpace>
2610e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
262d71ae5a4SJacob Faibussowitsch {
2630e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
264076ba34aSJunchao Zhang 
265076ba34aSJunchao Zhang   PetscFunctionBegin;
2660e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
267076ba34aSJunchao Zhang 
2680e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
269076ba34aSJunchao Zhang 
2700e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
271076ba34aSJunchao Zhang 
2720e3ece09SJunchao Zhang   if (vector_length < 1) {
2730e3ece09SJunchao Zhang     vector_length = 1;
2740e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
275076ba34aSJunchao Zhang   }
276076ba34aSJunchao Zhang 
2770e3ece09SJunchao Zhang   // Determine rows per thread
2780e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2790e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
2800e3ece09SJunchao Zhang     else {
2810e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2820e3ece09SJunchao Zhang         rows_per_thread = 256;
2830e3ece09SJunchao Zhang       } else rows_per_thread = 64;
284076ba34aSJunchao Zhang     }
285076ba34aSJunchao Zhang   }
286076ba34aSJunchao Zhang 
2870e3ece09SJunchao Zhang   if (team_size < 1) {
2880e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
2890e3ece09SJunchao Zhang       team_size = 256 / vector_length;
290076ba34aSJunchao Zhang     } else {
2910e3ece09SJunchao Zhang       team_size = 1;
2920e3ece09SJunchao Zhang     }
293076ba34aSJunchao Zhang   }
294076ba34aSJunchao Zhang 
2950e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
296076ba34aSJunchao Zhang 
2970e3ece09SJunchao Zhang   if (rows_per_team < 0) {
2980e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
2990e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
3000e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
3010e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
3020e3ece09SJunchao Zhang   }
3033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
304076ba34aSJunchao Zhang }
305076ba34aSJunchao Zhang 
3060e3ece09SJunchao Zhang /*
3070e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
308076ba34aSJunchao Zhang 
309076ba34aSJunchao Zhang   Input Parameters:
3100e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
3110e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
3120e3ece09SJunchao Zhang .  m           - size of indices[], the second set
3130e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
314076ba34aSJunchao Zhang 
315076ba34aSJunchao Zhang   Output Parameters:
3160e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
3170e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
3180e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
3190e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
320076ba34aSJunchao Zhang 
3210e3ece09SJunchao Zhang    Example, say
3220e3ece09SJunchao Zhang     n1         = 5
3230e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
3240e3ece09SJunchao Zhang     m          = 4
3250e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
32611a5261eSBarry Smith 
3270e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
3280e3ece09SJunchao Zhang     n2         = 7
3290e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
3300e3ece09SJunchao Zhang 
3310e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
3320e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
3330e3ece09SJunchao Zhang 
3340e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
3350e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
336076ba34aSJunchao Zhang */
3370e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
338d71ae5a4SJacob Faibussowitsch {
3390e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
3400e3ece09SJunchao Zhang   PetscHashIter iter;
3410e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
3420e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
343076ba34aSJunchao Zhang 
344076ba34aSJunchao Zhang   PetscFunctionBegin;
3450e3ece09SJunchao Zhang   tot = 0;
3460e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
3470e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
3480e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
3490e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
350076ba34aSJunchao Zhang   }
351076ba34aSJunchao Zhang 
3520e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
3530e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
3540e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
355076ba34aSJunchao Zhang   }
356076ba34aSJunchao Zhang 
3570e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
3580e3ece09SJunchao Zhang   n2 = tot;
3590e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
3600e3ece09SJunchao Zhang   tot = 0;
3610e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
3620e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
3630e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
3640e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
3650e3ece09SJunchao Zhang     garray2[tot++] = key;
366076ba34aSJunchao Zhang   }
367076ba34aSJunchao Zhang 
3680e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
3690e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
3700e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
3710e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
372f0e6e2d1SJunchao Zhang 
3730e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
374f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3750e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3760e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3770e3ece09SJunchao Zhang     indices[i] = val;
3780e3ece09SJunchao Zhang   }
3790e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3800e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3810e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3820e3ece09SJunchao Zhang   *n2_      = n2;
3830e3ece09SJunchao Zhang   *garray2_ = garray2;
3840e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3850e3ece09SJunchao Zhang }
386f0e6e2d1SJunchao Zhang 
3870e3ece09SJunchao Zhang /*
3880e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3890e3ece09SJunchao Zhang 
3900e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
3910e3ece09SJunchao Zhang 
  Think of each row of E as a leaf; the given ownerSF then specifies the roots of the leaves. A root may connect to multiple leaves.
3930e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3940e3ece09SJunchao Zhang 
3950e3ece09SJunchao Zhang   Input Parameters:
3960e3ece09SJunchao Zhang +  comm       - MPI communicator of E
3970e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
3980e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
3990e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
4000e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
4010e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
4020e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
4030e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
4040e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
4050e3ece09SJunchao Zhang 
4060e3ece09SJunchao Zhang   Output Parameters:
4070e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
4080e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
4090e3ece09SJunchao Zhang 
4100e3ece09SJunchao Zhang   Notes:
4110e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
4120e3ece09SJunchao Zhang 
4130e3ece09SJunchao Zhang  */
4140e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
4150e3ece09SJunchao Zhang {
4160e3ece09SJunchao Zhang   PetscFunctionBegin;
4170e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
4180e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
4190e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
4200e3ece09SJunchao Zhang 
4210e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
4220e3ece09SJunchao Zhang 
4230e3ece09SJunchao Zhang     // Do the analysis on host
4240e3ece09SJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
4250e3ece09SJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
4260e3ece09SJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
4270e3ece09SJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
4280e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
4290e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
4300e3ece09SJunchao Zhang 
4310e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
4327b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
4330e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
4340e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
4350e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
4360e3ece09SJunchao Zhang       PetscInt        count, step;
4370e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
4380e3ece09SJunchao Zhang       first = Bj + Bi[i];
4390e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
440f0e6e2d1SJunchao Zhang       count = last - first;
441f0e6e2d1SJunchao Zhang       while (count > 0) {
442f0e6e2d1SJunchao Zhang         it   = first;
443f0e6e2d1SJunchao Zhang         step = count / 2;
444f0e6e2d1SJunchao Zhang         it += step;
4450e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
446f0e6e2d1SJunchao Zhang           first = ++it;
447f0e6e2d1SJunchao Zhang           count -= step + 1;
448f0e6e2d1SJunchao Zhang         } else count = step;
449f0e6e2d1SJunchao Zhang       }
4500e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
4510e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
452f0e6e2d1SJunchao Zhang     }
453f0e6e2d1SJunchao Zhang 
4540e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
4550e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
4560e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
4570e3ece09SJunchao Zhang     PetscInt           niranks, nranks;
4580e3ece09SJunchao Zhang     MPI_Request       *reqs;
4590e3ece09SJunchao Zhang     PetscMPIInt        tag;
4600e3ece09SJunchao Zhang     PetscSF            reduceSF;
4610e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
462f0e6e2d1SJunchao Zhang 
4630e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
4640e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
4650e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
466f0e6e2d1SJunchao Zhang 
4670e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
4680e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
4690e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
4700e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
471f0e6e2d1SJunchao Zhang 
4720e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4730e3ece09SJunchao Zhang 
4740e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4750e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4760e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4770e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4780e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4790e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4800e3ece09SJunchao Zhang 
4810e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4820e3ece09SJunchao Zhang     rdisp[0] = 0;
4830e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4840e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4850e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4860e3ece09SJunchao Zhang     }
4870e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4880e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4890e3ece09SJunchao Zhang 
4900e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4910e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4920e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4930e3ece09SJunchao Zhang 
4940e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
4950e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
4960e3ece09SJunchao Zhang     PetscSFNode *iremote;
4970e3ece09SJunchao Zhang 
4980e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
4990e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
5000e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
5010e3ece09SJunchao Zhang 
5020e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
5030e3ece09SJunchao Zhang       PetscInt count = 0;
5040e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
5050e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
5060e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
5070e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
5080e3ece09SJunchao Zhang       }
5090e3ece09SJunchao Zhang       nleaves += count;
5100e3ece09SJunchao Zhang     }
5110e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
5120e3ece09SJunchao Zhang 
5130e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
5140e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
5150e3ece09SJunchao Zhang 
5160e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
5170e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
5180e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
5190e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
5200e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
5210e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
5220e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
5230e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
5240e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
5250e3ece09SJunchao Zhang         if (j < nzLeft) {
5260e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
5270e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
5280e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
5290e3ece09SJunchao Zhang         } else {
5300e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
5310e3ece09SJunchao Zhang         }
5320e3ece09SJunchao Zhang       }
5330e3ece09SJunchao Zhang     }
5340e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
5350e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
5360e3ece09SJunchao Zhang 
5370e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
5380e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
5390e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
5400e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
5410e3ece09SJunchao Zhang 
5420e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
5430e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
5440e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
5450e3ece09SJunchao Zhang 
5460e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
5477b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
5480e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
5490e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
5500e3ece09SJunchao Zhang     PetscInt                iter;
5510e3ece09SJunchao Zhang 
5520e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
5530e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
5540e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
5550e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
5560e3ece09SJunchao Zhang     iter = 0;
5570e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
5580e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
5590e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
5600e3ece09SJunchao Zhang 
5610e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5620e3ece09SJunchao Zhang 
5630e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
5640e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
5650e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
5660e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
5670e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
5680e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
5690e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5700e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
5710e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5720e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5730e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5740e3ece09SJunchao Zhang         jbuf2 += len;
5750e3ece09SJunchao Zhang         pbuf2 += len;
5760e3ece09SJunchao Zhang         nz += len;
5770e3ece09SJunchao Zhang       }
5780e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5790e3ece09SJunchao Zhang 
5800e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5810e3ece09SJunchao Zhang       PetscInt cur = 0;
5820e3ece09SJunchao Zhang       while (cur < nz) {
5830e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5840e3ece09SJunchao Zhang         PetscInt dups      = 1;
5850e3ece09SJunchao Zhang 
5860e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5870e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5880e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5890e3ece09SJunchao Zhang           FdnzDups += dups;
5900e3ece09SJunchao Zhang         } else {
5910e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5920e3ece09SJunchao Zhang           FonzDups += dups;
5930e3ece09SJunchao Zhang         }
5940e3ece09SJunchao Zhang         cur += dups;
5950e3ece09SJunchao Zhang       }
5960e3ece09SJunchao Zhang 
5970e3ece09SJunchao Zhang       FnzDups += nz;
5980e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5990e3ece09SJunchao Zhang     }
6000e3ece09SJunchao Zhang 
6010e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
6020e3ece09SJunchao Zhang     Foi = Foi_h.data();
6030e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
6040e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
6050e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
6060e3ece09SJunchao Zhang     }
6070e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
6080e3ece09SJunchao Zhang     Fonz = Foi[Fm];
6090e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
6100e3ece09SJunchao Zhang 
6110e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
6127b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
6137b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
6147b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
6150e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
6160e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
6170e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
6180e3ece09SJunchao Zhang 
6190e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
6200e3ece09SJunchao Zhang     Fdjmap[0] = 0;
6210e3ece09SJunchao Zhang     Fojmap[0] = 0;
6220e3ece09SJunchao Zhang     FnzDups   = 0;
6230e3ece09SJunchao Zhang     Fdnz      = 0;
6240e3ece09SJunchao Zhang     Fonz      = 0;
6250e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
6260e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
6270e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
6280e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
6290e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
6300e3ece09SJunchao Zhang 
6310e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
6320e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
6330e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
6340e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
6350e3ece09SJunchao Zhang       }
6360e3ece09SJunchao Zhang 
6370e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
6380e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
6390e3ece09SJunchao Zhang       PetscInt cur = 0;
6400e3ece09SJunchao Zhang       while (cur < nz) {
6410e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
6420e3ece09SJunchao Zhang         PetscInt dups      = 1;
6430e3ece09SJunchao Zhang 
6440e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
6450e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
6460e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
6470e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
6480e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
6490e3ece09SJunchao Zhang           FdnzDups += dups;
6500e3ece09SJunchao Zhang           Fdnz++;
6510e3ece09SJunchao Zhang         } else {
6520e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
6530e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
6540e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
6550e3ece09SJunchao Zhang           FonzDups += dups;
6560e3ece09SJunchao Zhang           Fonz++;
6570e3ece09SJunchao Zhang         }
6580e3ece09SJunchao Zhang         cur += dups;
6590e3ece09SJunchao Zhang         FnzDups += dups;
6600e3ece09SJunchao Zhang       }
6610e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6620e3ece09SJunchao Zhang     }
6630e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
6640e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
6650e3ece09SJunchao Zhang 
6660e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
6670e3ece09SJunchao Zhang     PetscInt n2, *garray2;
6680e3ece09SJunchao Zhang 
6690e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
6700e3ece09SJunchao Zhang     mm->sf       = reduceSF;
6717b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
6727b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
673aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6740e3ece09SJunchao Zhang     mm->n        = n2;
6750e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6760e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6770e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6780e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6790e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6800e3ece09SJunchao Zhang 
6810e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
6827b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6830e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6840e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
6857b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6860e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6870e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6880e3ece09SJunchao Zhang 
6890e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6900e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6910e3ece09SJunchao Zhang 
6920e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6930e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6940e3ece09SJunchao Zhang 
6950e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
6960e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
6970e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
6980e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
6990e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
7000e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
7010e3ece09SJunchao Zhang 
7020e3ece09SJunchao Zhang   // Handy aliases
7030e3ece09SJunchao Zhang   auto       &Aa           = A.values;
7040e3ece09SJunchao Zhang   auto       &Ba           = B.values;
7050e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
7060e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
7070e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
7080e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
7090e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
7100e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
7110e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
7120e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
7130e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
7140e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
7150e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
7160e3ece09SJunchao Zhang 
7170e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
7180e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
719*d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
7200e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
7210e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
7220e3ece09SJunchao Zhang         if (i < Em) {
7230e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
7240e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
7250e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
7260e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
7270e3ece09SJunchao Zhang 
7280e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
7290e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
7300e3ece09SJunchao Zhang             if (j < nzleft) { // B left
7310e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
7320e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
7330e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
7340e3ece09SJunchao Zhang             } else { // B right
7350e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
736f0e6e2d1SJunchao Zhang             }
737f0e6e2d1SJunchao Zhang           });
738f0e6e2d1SJunchao Zhang         }
739f0e6e2d1SJunchao Zhang       });
7400e3ece09SJunchao Zhang     }));
7410e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
742f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
743f0e6e2d1SJunchao Zhang }
7440e3ece09SJunchao Zhang 
745aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
7460e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
7470e3ece09SJunchao Zhang {
7480e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
7490e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
7500e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
7510e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
7520e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
7530e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
7540e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
7550e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
7560e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
7570e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
7580e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
7590e3ece09SJunchao Zhang 
760*d326c3f1SJunchao Zhang   PetscFunctionBegin;
7610e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
7620e3ece09SJunchao Zhang 
7630e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
7640e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
765*d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
7660e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7670e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
7680e3ece09SJunchao Zhang       Fda(i) = sum;
7690e3ece09SJunchao Zhang     }));
7700e3ece09SJunchao Zhang 
7710e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
772*d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
7730e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7740e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7750e3ece09SJunchao Zhang       Foa(i) = sum;
7760e3ece09SJunchao Zhang     }));
7770e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7780e3ece09SJunchao Zhang }
7790e3ece09SJunchao Zhang 
7800e3ece09SJunchao Zhang /*
7810e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7820e3ece09SJunchao Zhang 
  This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but it supports
  devices and involves various index mappings.
7850e3ece09SJunchao Zhang 
7860e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
7870e3ece09SJunchao Zhang   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
7880e3ece09SJunchao Zhang   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7890e3ece09SJunchao Zhang   F has the same column layout as E.
7900e3ece09SJunchao Zhang 
  Conceptually F has global column indices. In this routine, we split F into the diagonal block Fd and the off-diagonal block Fo.
792aaa8cc7dSPierre Jolivet   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
7930e3ece09SJunchao Zhang   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
7940e3ece09SJunchao Zhang   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
7950e3ece09SJunchao Zhang   column indices in Fo and update Fo with local indices.
7960e3ece09SJunchao Zhang 
7970e3ece09SJunchao Zhang    Input Parameters:
7980e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
7999c89aa79SPierre Jolivet .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
8000e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
8010e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
8020e3ece09SJunchao Zhang 
8030e3ece09SJunchao Zhang     Output Parameters:
8040e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
8050e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
8060e3ece09SJunchao Zhang 
8070e3ece09SJunchao Zhang     Notes:
    When reuse = MAT_REUSE_MATRIX, ownerSF and map are not significant.
    The routine is provided in split-phase form, MatMPIAIJKokkosBcastBegin/End(), to allow computation to overlap with communication.
8100e3ece09SJunchao Zhang */
8110e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
8120e3ece09SJunchao Zhang {
8130e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
8140e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
8150e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8160e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
8170e3ece09SJunchao Zhang   MPI_Comm          comm;
8180e3ece09SJunchao Zhang 
8190e3ece09SJunchao Zhang   PetscFunctionBegin;
8200e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
8210e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
8220e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
8230e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
8240e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
8250e3ece09SJunchao Zhang     PetscInt        cstart, cend;
8260e3ece09SJunchao Zhang     PetscSF         bcastSF;
8270e3ece09SJunchao Zhang 
8280e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
8290e3ece09SJunchao Zhang 
8300e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
8317b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
8320e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
8330e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
8340e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
8350e3ece09SJunchao Zhang       PetscInt        count, step;
8360e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
8370e3ece09SJunchao Zhang       first = Bj + Bi[i];
8380e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
8390e3ece09SJunchao Zhang       count = last - first;
8400e3ece09SJunchao Zhang       while (count > 0) {
8410e3ece09SJunchao Zhang         it   = first;
8420e3ece09SJunchao Zhang         step = count / 2;
8430e3ece09SJunchao Zhang         it += step;
8440e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
8450e3ece09SJunchao Zhang           first = ++it;
8460e3ece09SJunchao Zhang           count -= step + 1;
8470e3ece09SJunchao Zhang         } else count = step;
8480e3ece09SJunchao Zhang       }
8490e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
8500e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
8510e3ece09SJunchao Zhang     }
8520e3ece09SJunchao Zhang 
8530e3ece09SJunchao Zhang     // Compute row pointer Fi of F
8540e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
8550e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
8560e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
8570e3ece09SJunchao Zhang     Fi[0] = 0;
8580e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
8590e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
8600e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
8610e3ece09SJunchao Zhang     Fnz = Fi[Fm];
8620e3ece09SJunchao Zhang 
8630e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
8640e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
8650e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
8660e3ece09SJunchao Zhang     PetscInt           niranks, nranks, *sdisp, *rdisp;
8670e3ece09SJunchao Zhang     MPI_Request       *reqs;
8680e3ece09SJunchao Zhang     PetscMPIInt        tag;
8690e3ece09SJunchao Zhang 
8700e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8710e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8720e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8730e3ece09SJunchao Zhang 
8740e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8750e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8760e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8770e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8780e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8790e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8800e3ece09SJunchao Zhang       }
8810e3ece09SJunchao Zhang     }
8820e3ece09SJunchao Zhang 
8830e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8840e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8850e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8860e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8870e3ece09SJunchao Zhang 
8880e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8890e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8900e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8910e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8920e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8930e3ece09SJunchao Zhang       PetscInt k = 0;
8940e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
8950e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
8960e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
8970e3ece09SJunchao Zhang         k++;
8980e3ece09SJunchao Zhang       }
8990e3ece09SJunchao Zhang     }
9000e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
9010e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
9020e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
9030e3ece09SJunchao Zhang 
9040e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
9057b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
9060e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
9070e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
9087b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
9090e3ece09SJunchao Zhang 
9100e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
9110e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
9120e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
9130e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
9140e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
9150e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
9160e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
9170e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
9180e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
9190e3ece09SJunchao Zhang         if (j < nzLeft) {
9200e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
9210e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
9220e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
9230e3ece09SJunchao Zhang         } else {
9240e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
9250e3ece09SJunchao Zhang         }
9260e3ece09SJunchao Zhang       }
9270e3ece09SJunchao Zhang     }
9280e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
9290e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
9300e3ece09SJunchao Zhang 
9310e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
9327b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
9337b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
9340e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
9350e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
9360e3ece09SJunchao Zhang 
9370e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
9380e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9390e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
9400e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
9410e3ece09SJunchao Zhang       first       = Fj + Fi[i];
9420e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
9430e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
9440e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
9450e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
9460e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
9470e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
9480e3ece09SJunchao Zhang     }
9490e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9500e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
9510e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
9520e3ece09SJunchao Zhang     }
9530e3ece09SJunchao Zhang 
9540e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
9550e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
9567b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
9570e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
9580e3ece09SJunchao Zhang 
9590e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9600e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
9610e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
9620e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
9630e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
9640e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
9650e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
9660e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
9670e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
9680e3ece09SJunchao Zhang         } else { // right, in global
9690e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
9700e3ece09SJunchao Zhang         }
9710e3ece09SJunchao Zhang       }
9720e3ece09SJunchao Zhang     }
9730e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9740e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9750e3ece09SJunchao Zhang 
9760e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9770e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9780e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9790e3ece09SJunchao Zhang 
9800e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9810e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
9827b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9830e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9840e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9850e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9860e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9870e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9880e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
9897b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
9907b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9910e3ece09SJunchao Zhang     mm->garray    = garray2;
9920e3ece09SJunchao Zhang     mm->n         = n2;
9930e3ece09SJunchao Zhang 
9940e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
9957b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
9960e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
9970e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
9980e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
9990e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
10000e3ece09SJunchao Zhang 
10010e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
10020e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
10030e3ece09SJunchao Zhang 
10040e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
10050e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
10060e3ece09SJunchao Zhang 
10070e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10080e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
10090e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
10100e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
10110e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
10120e3ece09SJunchao Zhang 
10130e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10140e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
10150e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
10160e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
10170e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
10180e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
10190e3ece09SJunchao Zhang 
10200e3ece09SJunchao Zhang   // Sync E's value to device
10210e3ece09SJunchao Zhang   akok->a_dual.sync_device();
10220e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
10230e3ece09SJunchao Zhang 
10240e3ece09SJunchao Zhang   // Handy aliases
10250e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
10260e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
10270e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
10280e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
10290e3ece09SJunchao Zhang 
10300e3ece09SJunchao Zhang   // Fetch the plans
10310e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
10320e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
10330e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
10340e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
10350e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
10360e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
10370e3ece09SJunchao Zhang 
10380e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
10390e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
10400e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
10410e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
10420e3ece09SJunchao Zhang 
10430e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
10440e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1045*d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10460e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10470e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
10480e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
10490e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
10500e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
10510e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
10520e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
10530e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
10540e3ece09SJunchao Zhang 
10550e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10560e3ece09SJunchao Zhang             if (j < nzleft) { // B left
10570e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
10580e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
10590e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
10600e3ece09SJunchao Zhang             } else { // B right
10610e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
10620e3ece09SJunchao Zhang             }
10630e3ece09SJunchao Zhang           });
10640e3ece09SJunchao Zhang         }
10650e3ece09SJunchao Zhang       });
10660e3ece09SJunchao Zhang     }));
10670e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
10680e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10690e3ece09SJunchao Zhang }
10700e3ece09SJunchao Zhang 
10710e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10720e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10730e3ece09SJunchao Zhang {
10740e3ece09SJunchao Zhang   PetscFunctionBegin;
10750e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10760e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10770e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10780e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10790e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10800e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10810e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10820e3ece09SJunchao Zhang 
10830e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10840e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10850e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10860e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10870e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10880e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10890e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10900e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10910e3ece09SJunchao Zhang 
10920e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10930e3ece09SJunchao Zhang 
10940e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
10950e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1096*d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10970e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10980e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
10990e3ece09SJunchao Zhang         if (i < Fm) {
11000e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
11010e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
11020e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
11030e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
11040e3ece09SJunchao Zhang 
11050e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
11060e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
11070e3ece09SJunchao Zhang             if (j < nzLeft) { // left
11080e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
11090e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
11100e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
11110e3ece09SJunchao Zhang             } else { // right
11120e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
11130e3ece09SJunchao Zhang             }
11140e3ece09SJunchao Zhang           });
11150e3ece09SJunchao Zhang         }
11160e3ece09SJunchao Zhang       });
11170e3ece09SJunchao Zhang     }));
11180e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11190e3ece09SJunchao Zhang }
11200e3ece09SJunchao Zhang 
11210e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11220e3ece09SJunchao Zhang {
11230e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11240e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11250e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
11260e3ece09SJunchao Zhang   PetscInt        cstart, cend;
11270e3ece09SJunchao Zhang   MPI_Comm        comm;
11280e3ece09SJunchao Zhang 
11290e3ece09SJunchao Zhang   PetscFunctionBegin;
11300e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11310e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11320e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11330e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
11340e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
11350e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11360e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11370e3ece09SJunchao Zhang 
11380e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11390e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11400e3ece09SJunchao Zhang 
11410e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11420e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11430e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11440e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1145f0e6e2d1SJunchao Zhang   #endif
11460e3ece09SJunchao Zhang #endif
11470e3ece09SJunchao Zhang 
11480e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
11490e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
11500e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
11510e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
11520e3ece09SJunchao Zhang 
11530e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11540e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
11550e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
11560e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
11570e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
11580e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11590e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11600e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1161*d326c3f1SJunchao Zhang 
11620e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
11630e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
11640e3ece09SJunchao Zhang #endif
11650e3ece09SJunchao Zhang 
11660e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11677b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11680e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
11690e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11700e3ece09SJunchao Zhang 
11710e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11720e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11730e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11740e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11750e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11760e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11770e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11780e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11790e3ece09SJunchao Zhang #endif
11800e3ece09SJunchao Zhang 
11810e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11820e3ece09SJunchao Zhang 
11830e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
11847b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11850e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1186*d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11870e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11880e3ece09SJunchao Zhang 
11890e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11900e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11910e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11920e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11930e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11940e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11950e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11960e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11970e3ece09SJunchao Zhang }
11980e3ece09SJunchao Zhang 
11990e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
12000e3ece09SJunchao Zhang {
12010e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12020e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12030e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
12040e3ece09SJunchao Zhang   MPI_Comm        comm;
12050e3ece09SJunchao Zhang 
12060e3ece09SJunchao Zhang   PetscFunctionBegin;
12070e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
12080e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
12090e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
12100e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12110e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12120e3ece09SJunchao Zhang 
12130e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
12140e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
12150e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
12160e3ece09SJunchao Zhang 
12170e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
12180e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12190e3ece09SJunchao Zhang 
12200e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
12210e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
12220e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
12230e3ece09SJunchao Zhang 
12240e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12250e3ece09SJunchao Zhang 
12260e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
12270e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
12280e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
12290e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12300e3ece09SJunchao Zhang }
1231f0e6e2d1SJunchao Zhang 
1232076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1233076ba34aSJunchao Zhang 
1234076ba34aSJunchao Zhang   Input Parameters:
1235076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1236076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1237076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1238076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1239076ba34aSJunchao Zhang */
1240d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241d71ae5a4SJacob Faibussowitsch {
12420e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12430e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12440e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245076ba34aSJunchao Zhang 
1246076ba34aSJunchao Zhang   PetscFunctionBegin;
12470e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12480e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12490e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12500e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12510e3ece09SJunchao Zhang 
12520e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
12530e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
12540e3ece09SJunchao Zhang 
12550e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
12560e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
12570e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
12580e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
12590e3ece09SJunchao Zhang   #endif
1260f0e6e2d1SJunchao Zhang #endif
1261f0e6e2d1SJunchao Zhang 
12620e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
12630e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
12640e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
12650e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1266076ba34aSJunchao Zhang 
12670e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12687b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
12690e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1270076ba34aSJunchao Zhang 
12710e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12720e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12730e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12740e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12750e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12760e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12770e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12780e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12790e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12800e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12810e3ece09SJunchao Zhang #endif
1282076ba34aSJunchao Zhang 
12830e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1284076ba34aSJunchao Zhang 
12850e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12860e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12870e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12880e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12890e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12900e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12910e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12920e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12930e3ece09SJunchao Zhang #endif
1294076ba34aSJunchao Zhang 
12950e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
12967b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
12970e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1298*d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
12990e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
13000e3ece09SJunchao Zhang 
13010e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13020e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
13030e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
13040e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
13050e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
13060e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13070e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13083ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1309076ba34aSJunchao Zhang }
1310076ba34aSJunchao Zhang 
13110e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1312d71ae5a4SJacob Faibussowitsch {
13130e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
13140e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
13150e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1316076ba34aSJunchao Zhang 
1317076ba34aSJunchao Zhang   PetscFunctionBegin;
13180e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
13190e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
13200e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
13210e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1322076ba34aSJunchao Zhang 
13230e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
13240e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1325076ba34aSJunchao Zhang 
13260e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
13270e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
13280e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1329076ba34aSJunchao Zhang 
13300e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1331076ba34aSJunchao Zhang 
13320e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
13330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
13340e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
13350e3ece09SJunchao Zhang 
13360e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13370e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13380e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13393ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1340076ba34aSJunchao Zhang }
1341076ba34aSJunchao Zhang 
134266976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1343d71ae5a4SJacob Faibussowitsch {
13440e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
13450e3ece09SJunchao Zhang   Mat_Product                 *product;
13460e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1347076ba34aSJunchao Zhang   MatProductType               ptype;
13480e3ece09SJunchao Zhang   Mat                          A, B;
1349076ba34aSJunchao Zhang 
1350076ba34aSJunchao Zhang   PetscFunctionBegin;
13510e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
13520e3ece09SJunchao Zhang   product = C->product;
13530e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1354076ba34aSJunchao Zhang   ptype   = product->type;
1355076ba34aSJunchao Zhang   A       = product->A;
1356076ba34aSJunchao Zhang   B       = product->B;
1357076ba34aSJunchao Zhang 
13580e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
13590e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
13600e3ece09SJunchao Zhang   // we still do numeric.
13610e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
13620e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13633ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1364076ba34aSJunchao Zhang   }
1365076ba34aSJunchao Zhang 
1366076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13670e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1368076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13690e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
13700e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13710e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13720e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1373076ba34aSJunchao Zhang   }
13740e3ece09SJunchao Zhang 
13750e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13760e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1378076ba34aSJunchao Zhang }
1379076ba34aSJunchao Zhang 
138066976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1381d71ae5a4SJacob Faibussowitsch {
1382076ba34aSJunchao Zhang   Mat                          A, B;
13830e3ece09SJunchao Zhang   Mat_Product                 *product;
1384076ba34aSJunchao Zhang   MatProductType               ptype;
13850e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1386076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13870e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13880e3ece09SJunchao Zhang   Mat                          Cd, Co;
13890e3ece09SJunchao Zhang   MPI_Comm                     comm;
1390076ba34aSJunchao Zhang 
1391076ba34aSJunchao Zhang   PetscFunctionBegin;
13920e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1393076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13940e3ece09SJunchao Zhang   product = C->product;
13950e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1396076ba34aSJunchao Zhang   ptype = product->type;
1397076ba34aSJunchao Zhang   A     = product->A;
1398076ba34aSJunchao Zhang   B     = product->B;
1399076ba34aSJunchao Zhang 
1400076ba34aSJunchao Zhang   switch (ptype) {
14019371c9d4SSatish Balay   case MATPRODUCT_AB:
14029371c9d4SSatish Balay     m = A->rmap->n;
14039371c9d4SSatish Balay     n = B->cmap->n;
14049371c9d4SSatish Balay     M = A->rmap->N;
14059371c9d4SSatish Balay     N = B->cmap->N;
14069371c9d4SSatish Balay     break;
14079371c9d4SSatish Balay   case MATPRODUCT_AtB:
14089371c9d4SSatish Balay     m = A->cmap->n;
14099371c9d4SSatish Balay     n = B->cmap->n;
14109371c9d4SSatish Balay     M = A->cmap->N;
14119371c9d4SSatish Balay     N = B->cmap->N;
14129371c9d4SSatish Balay     break;
14139371c9d4SSatish Balay   case MATPRODUCT_PtAP:
14149371c9d4SSatish Balay     m = B->cmap->n;
14159371c9d4SSatish Balay     n = B->cmap->n;
14169371c9d4SSatish Balay     M = B->cmap->N;
14179371c9d4SSatish Balay     N = B->cmap->N;
14189371c9d4SSatish Balay     break; /* BtAB */
1419d71ae5a4SJacob Faibussowitsch   default:
14200e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1421076ba34aSJunchao Zhang   }
1422076ba34aSJunchao Zhang 
14239566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
14249566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
14259566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
14269566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1427076ba34aSJunchao Zhang 
14280e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
14290e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1430076ba34aSJunchao Zhang 
1431076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
14320e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14330e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
14340e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1435076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
14360e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14370e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
14380e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
14390e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
14400e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
14410e3ece09SJunchao Zhang 
14420e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14430e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
14440e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
14450e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
14460e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
14470e3ece09SJunchao Zhang 
14480e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
14490e3ece09SJunchao Zhang     n = B->cmap->n;
14500e3ece09SJunchao Zhang     M = A->rmap->N;
14510e3ece09SJunchao Zhang     N = B->cmap->N;
14520e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
14530e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
14540e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
14550e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
14560e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
14570e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
14580e3ece09SJunchao Zhang 
14590e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14600e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
14610e3ece09SJunchao Zhang 
14620e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
14630e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1464076ba34aSJunchao Zhang   }
14650e3ece09SJunchao Zhang 
14660e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
14670e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
14680e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
14690e3ece09SJunchao Zhang 
14700e3ece09SJunchao Zhang   C->product->data       = pdata;
1471076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1472076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1474076ba34aSJunchao Zhang }
1475076ba34aSJunchao Zhang 
1476d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1477d71ae5a4SJacob Faibussowitsch {
1478076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1479076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1480076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1481076ba34aSJunchao Zhang 
1482076ba34aSJunchao Zhang   PetscFunctionBegin;
1483076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
148448a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1485076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1486076ba34aSJunchao Zhang     switch (product->type) {
1487076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1488076ba34aSJunchao Zhang       if (product->api_user) {
1489d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14909566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1491d0609cedSBarry Smith         PetscOptionsEnd();
1492076ba34aSJunchao Zhang       } else {
1493d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14949566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1495d0609cedSBarry Smith         PetscOptionsEnd();
1496076ba34aSJunchao Zhang       }
1497076ba34aSJunchao Zhang       break;
1498076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1499076ba34aSJunchao Zhang       if (product->api_user) {
1500d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
15019566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1502d0609cedSBarry Smith         PetscOptionsEnd();
1503076ba34aSJunchao Zhang       } else {
1504d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
15059566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1506d0609cedSBarry Smith         PetscOptionsEnd();
1507076ba34aSJunchao Zhang       }
1508076ba34aSJunchao Zhang       break;
1509076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1510076ba34aSJunchao Zhang       if (product->api_user) {
1511d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
15129566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1513d0609cedSBarry Smith         PetscOptionsEnd();
1514076ba34aSJunchao Zhang       } else {
1515d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
15169566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1517d0609cedSBarry Smith         PetscOptionsEnd();
1518076ba34aSJunchao Zhang       }
1519076ba34aSJunchao Zhang       break;
1520d71ae5a4SJacob Faibussowitsch     default:
1521d71ae5a4SJacob Faibussowitsch       break;
1522076ba34aSJunchao Zhang     }
1523076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1524076ba34aSJunchao Zhang   }
1525076ba34aSJunchao Zhang   if (match) {
1526076ba34aSJunchao Zhang     switch (product->type) {
1527076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1528076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1529d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1530d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1531d71ae5a4SJacob Faibussowitsch       break;
1532d71ae5a4SJacob Faibussowitsch     default:
1533d71ae5a4SJacob Faibussowitsch       break;
1534076ba34aSJunchao Zhang     }
1535076ba34aSJunchao Zhang   }
1536076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
153748a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
15383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1539076ba34aSJunchao Zhang }
1540076ba34aSJunchao Zhang 
15412c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device
15422c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos {
15432c4ab24aSJunchao Zhang   PetscCount           n;
15442c4ab24aSJunchao Zhang   PetscSF              sf;
15452c4ab24aSJunchao Zhang   PetscCount           Annz, Bnnz;
15462c4ab24aSJunchao Zhang   PetscCount           Annz2, Bnnz2;
15472c4ab24aSJunchao Zhang   PetscCountKokkosView Ajmap1, Aperm1;
15482c4ab24aSJunchao Zhang   PetscCountKokkosView Bjmap1, Bperm1;
15492c4ab24aSJunchao Zhang   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
15502c4ab24aSJunchao Zhang   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
15512c4ab24aSJunchao Zhang   PetscCountKokkosView Cperm1;
15522c4ab24aSJunchao Zhang   MatScalarKokkosView  sendbuf, recvbuf;
15532c4ab24aSJunchao Zhang 
15542c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
15552c4ab24aSJunchao Zhang     n(coo_h->n),
15562c4ab24aSJunchao Zhang     sf(coo_h->sf),
15572c4ab24aSJunchao Zhang     Annz(coo_h->Annz),
15582c4ab24aSJunchao Zhang     Bnnz(coo_h->Bnnz),
15592c4ab24aSJunchao Zhang     Annz2(coo_h->Annz2),
15602c4ab24aSJunchao Zhang     Bnnz2(coo_h->Bnnz2),
15612c4ab24aSJunchao Zhang     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
15622c4ab24aSJunchao Zhang     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
15632c4ab24aSJunchao Zhang     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
15642c4ab24aSJunchao Zhang     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
15652c4ab24aSJunchao Zhang     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
15662c4ab24aSJunchao Zhang     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
15672c4ab24aSJunchao Zhang     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
15682c4ab24aSJunchao Zhang     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
15692c4ab24aSJunchao Zhang     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
15702c4ab24aSJunchao Zhang     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
15712c4ab24aSJunchao Zhang     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
15726892b982SJunchao Zhang     sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
15736892b982SJunchao Zhang     recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
15742c4ab24aSJunchao Zhang   {
15752c4ab24aSJunchao Zhang     PetscCallVoid(PetscObjectReference((PetscObject)sf));
15762c4ab24aSJunchao Zhang   }
15772c4ab24aSJunchao Zhang 
15782c4ab24aSJunchao Zhang   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
15792c4ab24aSJunchao Zhang };
15802c4ab24aSJunchao Zhang 
15812c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
15822c4ab24aSJunchao Zhang {
15832c4ab24aSJunchao Zhang   PetscFunctionBegin;
15842c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
15852c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15862c4ab24aSJunchao Zhang }
15872c4ab24aSJunchao Zhang 
1588d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1589d71ae5a4SJacob Faibussowitsch {
15902c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
15912c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJ       *coo_h;
15922c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo_d;
159342550becSJunchao Zhang 
159442550becSJunchao Zhang   PetscFunctionBegin;
159530203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1596cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15979566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15989566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15999566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
16002c4ab24aSJunchao Zhang 
16012c4ab24aSJunchao Zhang   // Copy the COO struct to device
16022c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
16032c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
16042c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
16052c4ab24aSJunchao Zhang 
16062c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
16072c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
16082c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
16092c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
16102c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
16112c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
16123ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
161342550becSJunchao Zhang }
161442550becSJunchao Zhang 
1615d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1616d71ae5a4SJacob Faibussowitsch {
1617394ed5ebSJunchao Zhang   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
161842550becSJunchao Zhang   Mat                        A = mpiaij->A, B = mpiaij->B;
161942550becSJunchao Zhang   MatScalarKokkosView        Aa, Ba;
1620394ed5ebSJunchao Zhang   MatScalarKokkosView        v1;
162142550becSJunchao Zhang   PetscMemType               memtype;
16222c4ab24aSJunchao Zhang   PetscContainer             container;
16232c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo;
162442550becSJunchao Zhang 
162542550becSJunchao Zhang   PetscFunctionBegin;
16262c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
16272c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
16282c4ab24aSJunchao Zhang 
16292c4ab24aSJunchao Zhang   const auto &n      = coo->n;
16302c4ab24aSJunchao Zhang   const auto &Annz   = coo->Annz;
16312c4ab24aSJunchao Zhang   const auto &Annz2  = coo->Annz2;
16322c4ab24aSJunchao Zhang   const auto &Bnnz   = coo->Bnnz;
16332c4ab24aSJunchao Zhang   const auto &Bnnz2  = coo->Bnnz2;
16342c4ab24aSJunchao Zhang   const auto &vsend  = coo->sendbuf;
16352c4ab24aSJunchao Zhang   const auto &v2     = coo->recvbuf;
16362c4ab24aSJunchao Zhang   const auto &Ajmap1 = coo->Ajmap1;
16372c4ab24aSJunchao Zhang   const auto &Ajmap2 = coo->Ajmap2;
16382c4ab24aSJunchao Zhang   const auto &Aimap2 = coo->Aimap2;
16392c4ab24aSJunchao Zhang   const auto &Bjmap1 = coo->Bjmap1;
16402c4ab24aSJunchao Zhang   const auto &Bjmap2 = coo->Bjmap2;
16412c4ab24aSJunchao Zhang   const auto &Bimap2 = coo->Bimap2;
16422c4ab24aSJunchao Zhang   const auto &Aperm1 = coo->Aperm1;
16432c4ab24aSJunchao Zhang   const auto &Aperm2 = coo->Aperm2;
16442c4ab24aSJunchao Zhang   const auto &Bperm1 = coo->Bperm1;
16452c4ab24aSJunchao Zhang   const auto &Bperm2 = coo->Bperm2;
16462c4ab24aSJunchao Zhang   const auto &Cperm1 = coo->Cperm1;
16472c4ab24aSJunchao Zhang 
16489566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
164942550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
16502c4ab24aSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
165142550becSJunchao Zhang   } else {
16522c4ab24aSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
165342550becSJunchao Zhang   }
165442550becSJunchao Zhang 
165542550becSJunchao Zhang   if (imode == INSERT_VALUES) {
16569566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
16579566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1658394ed5ebSJunchao Zhang   } else {
16599566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
16609566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
166142550becSJunchao Zhang   }
166242550becSJunchao Zhang 
166308bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
166442550becSJunchao Zhang   /* Pack entries to be sent to remote */
1665*d326c3f1SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
166642550becSJunchao Zhang 
166742550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
16682c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1669158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
16709371c9d4SSatish Balay   Kokkos::parallel_for(
1671*d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1672158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1673158ec288SJunchao Zhang       if (i < Annz) {
1674158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1675ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1676158ec288SJunchao Zhang       } else {
1677158ec288SJunchao Zhang         i -= Annz;
1678158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1679ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1680158ec288SJunchao Zhang       }
1681158ec288SJunchao Zhang     });
16822c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
168342550becSJunchao Zhang 
1684158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16859371c9d4SSatish Balay   Kokkos::parallel_for(
1686*d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1687158ec288SJunchao Zhang       if (i < Annz2) {
1688158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1689158ec288SJunchao Zhang       } else {
1690158ec288SJunchao Zhang         i -= Annz2;
1691158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1692158ec288SJunchao Zhang       }
1693158ec288SJunchao Zhang     });
169408bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
169542550becSJunchao Zhang 
1696394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
16979566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
16989566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1699394ed5ebSJunchao Zhang   } else {
17009566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
17019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1702394ed5ebSJunchao Zhang   }
17033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170442550becSJunchao Zhang }
170542550becSJunchao Zhang 
17062c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1707d71ae5a4SJacob Faibussowitsch {
1708076ba34aSJunchao Zhang   PetscFunctionBegin;
17099566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
17109566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
17119566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
17129566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
17139566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
17143ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1715076ba34aSJunchao Zhang }
1716076ba34aSJunchao Zhang 
1717f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1718f4747e26SJunchao Zhang {
1719f4747e26SJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1720f4747e26SJunchao Zhang   PetscBool   congruent;
1721f4747e26SJunchao Zhang 
1722f4747e26SJunchao Zhang   PetscFunctionBegin;
1723f4747e26SJunchao Zhang   PetscCall(MatHasCongruentLayouts(A, &congruent));
1724f4747e26SJunchao Zhang   if (congruent) { // square matrix and the diagonals are solely in the diag block
1725f4747e26SJunchao Zhang     PetscCall(MatShift(mpiaij->A, a));
1726f4747e26SJunchao Zhang   } else { // too hard, use the general version
1727f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1728f4747e26SJunchao Zhang   }
1729f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1730f4747e26SJunchao Zhang }
1731f4747e26SJunchao Zhang 
17322c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
17332c4ab24aSJunchao Zhang {
17342c4ab24aSJunchao Zhang   PetscFunctionBegin;
17352c4ab24aSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
17362c4ab24aSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
17372c4ab24aSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
17382c4ab24aSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
17392c4ab24aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
17402c4ab24aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1741f4747e26SJunchao Zhang   B->ops->shift                 = MatShift_MPIAIJKokkos;
17422c4ab24aSJunchao Zhang 
17432c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
17442c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
17452c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
17462c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
17472c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
17482c4ab24aSJunchao Zhang }
17492c4ab24aSJunchao Zhang 
1750d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1751d71ae5a4SJacob Faibussowitsch {
17528c3ff71bSJunchao Zhang   Mat         B;
1753076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
17548c3ff71bSJunchao Zhang 
17558c3ff71bSJunchao Zhang   PetscFunctionBegin;
17568c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17579566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
17588c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17599566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
17608c3ff71bSJunchao Zhang   }
17618c3ff71bSJunchao Zhang   B = *newmat;
17628c3ff71bSJunchao Zhang 
17636f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17649566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
17659566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
17669566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
17678c3ff71bSJunchao Zhang 
1768076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
17699566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
17709566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
17719566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
17722c4ab24aSJunchao Zhang   PetscCall(MatSetOps_MPIAIJKokkos(B));
17733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17748c3ff71bSJunchao Zhang }
17752c4ab24aSJunchao Zhang 
17763f3ba80aSJunchao Zhang /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
17788c3ff71bSJunchao Zhang 
177915229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
17803f3ba80aSJunchao Zhang 
17812ef1f0ffSBarry Smith    Options Database Key:
17822ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
17833f3ba80aSJunchao Zhang 
17843f3ba80aSJunchao Zhang   Level: beginner
17853f3ba80aSJunchao Zhang 
17861cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
17873f3ba80aSJunchao Zhang M*/
1788d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1789d71ae5a4SJacob Faibussowitsch {
17908c3ff71bSJunchao Zhang   PetscFunctionBegin;
17919566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
17929566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
17939566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
17943ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17958c3ff71bSJunchao Zhang }
17968c3ff71bSJunchao Zhang 
17978c3ff71bSJunchao Zhang /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
18018c3ff71bSJunchao Zhang 
18028c3ff71bSJunchao Zhang   Collective
18038c3ff71bSJunchao Zhang 
18048c3ff71bSJunchao Zhang   Input Parameters:
+ comm  - MPI communicator
. m     - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given)
           This value should be the same as the local size used in creating the
           y vector for the matrix-vector product y = Ax.
. n     - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given).
           This value should be the same as the local size used in creating the
           x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
. M     - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
. N     - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
181420f4b53cSBarry Smith . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
181520f4b53cSBarry Smith            (same value is used for all local rows)
181620f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the
181720f4b53cSBarry Smith            DIAGONAL portion of the local submatrix (possibly different for each row)
181820f4b53cSBarry Smith            or `NULL`, if `d_nz` is used to specify the nonzero structure.
           The size of this array is equal to the number of local rows, i.e. `m`.
182020f4b53cSBarry Smith            For matrices you plan to factor you must leave room for the diagonal entry and
182120f4b53cSBarry Smith            put in the entry even if it is zero.
182220f4b53cSBarry Smith . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
182320f4b53cSBarry Smith            submatrix (same value is used for all local rows).
182420f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the
182520f4b53cSBarry Smith            OFF-DIAGONAL portion of the local submatrix (possibly different for
182620f4b53cSBarry Smith            each row) or `NULL`, if `o_nz` is used to specify the nonzero
182720f4b53cSBarry Smith            structure. The size of this array is equal to the number
           of local rows, i.e. `m`.
18298c3ff71bSJunchao Zhang 
18308c3ff71bSJunchao Zhang   Output Parameter:
18318c3ff71bSJunchao Zhang . A - the matrix
18328c3ff71bSJunchao Zhang 
18332ef1f0ffSBarry Smith   Level: intermediate
18342ef1f0ffSBarry Smith 
18352ef1f0ffSBarry Smith   Notes:
183611a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
18378c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
183811a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
18398c3ff71bSJunchao Zhang 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
  storage.  That is, the stored row and column indices can begin at
  either one (as in Fortran) or zero.

  If `comm` contains a single MPI process, a `MATSEQAIJKOKKOS` matrix is created and `o_nz` and `o_nnz` are ignored.
18438c3ff71bSJunchao Zhang 
.seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
          `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
18468c3ff71bSJunchao Zhang @*/
1847d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1848d71ae5a4SJacob Faibussowitsch {
18498c3ff71bSJunchao Zhang   PetscMPIInt size;
18508c3ff71bSJunchao Zhang 
18518c3ff71bSJunchao Zhang   PetscFunctionBegin;
18529566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
18539566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
18549566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
18558c3ff71bSJunchao Zhang   if (size > 1) {
18569566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
18579566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
18588c3ff71bSJunchao Zhang   } else {
18599566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
18609566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
18618c3ff71bSJunchao Zhang   }
18623ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18638c3ff71bSJunchao Zhang }
1864