xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 9289612345f7f660754a0183484c072fccee64a8)
1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
5*92896123SJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
62c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
78c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
8076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
90e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
1011d22bbfSJunchao Zhang 
1166976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12d71ae5a4SJacob Faibussowitsch {
1330203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
148c3ff71bSJunchao Zhang 
158c3ff71bSJunchao Zhang   PetscFunctionBegin;
169566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1730203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1830203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1930203840SJunchao Zhang    */
2030203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
21*92896123SJunchao Zhang     PetscScalarKokkosView v;
22*92896123SJunchao Zhang 
2330203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2430203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25*92896123SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
26*92896123SJunchao Zhang     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27*92896123SJunchao Zhang     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
2830203840SJunchao Zhang   }
293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
308c3ff71bSJunchao Zhang }
318c3ff71bSJunchao Zhang 
3266976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33d71ae5a4SJacob Faibussowitsch {
348c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
358c3ff71bSJunchao Zhang 
368c3ff71bSJunchao Zhang   PetscFunctionBegin;
37913a874dSJunchao Zhang   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
38913a874dSJunchao Zhang   if (mat->hash_active) {
39913a874dSJunchao Zhang     mat->ops[0]      = mpiaij->cops;
40913a874dSJunchao Zhang     mat->hash_active = PETSC_FALSE;
41913a874dSJunchao Zhang   }
42913a874dSJunchao Zhang 
439566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
449566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
456a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
468c3ff71bSJunchao Zhang   if (d_nnz) {
476a29ce69SStefano Zampini     PetscInt i;
48ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
498c3ff71bSJunchao Zhang   }
508c3ff71bSJunchao Zhang   if (o_nnz) {
516a29ce69SStefano Zampini     PetscInt i;
52ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
538c3ff71bSJunchao Zhang   }
546a29ce69SStefano Zampini #endif
556a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
56eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
576a29ce69SStefano Zampini #else
589566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
596a29ce69SStefano Zampini #endif
609566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
619566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
629566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
636a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
649566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
656a29ce69SStefano Zampini 
666a29ce69SStefano Zampini   if (!mpiaij->A) {
679566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
689566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
696a29ce69SStefano Zampini   }
706a29ce69SStefano Zampini   if (!mpiaij->B) {
716a29ce69SStefano Zampini     PetscMPIInt size;
729566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
739566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
749566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
758c3ff71bSJunchao Zhang   }
769566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
779566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
789566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
799566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
808c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
828c3ff71bSJunchao Zhang }
838c3ff71bSJunchao Zhang 
8466976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
85d71ae5a4SJacob Faibussowitsch {
868c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
878c3ff71bSJunchao Zhang   PetscInt    nt;
888c3ff71bSJunchao Zhang 
898c3ff71bSJunchao Zhang   PetscFunctionBegin;
909566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
9108401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
929566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
939566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
949566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
959566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
978c3ff71bSJunchao Zhang }
988c3ff71bSJunchao Zhang 
9966976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
100d71ae5a4SJacob Faibussowitsch {
1018c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1028c3ff71bSJunchao Zhang   PetscInt    nt;
1038c3ff71bSJunchao Zhang 
1048c3ff71bSJunchao Zhang   PetscFunctionBegin;
1059566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
10608401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
1079566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1089566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
1099566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1109566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
1113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1128c3ff71bSJunchao Zhang }
1138c3ff71bSJunchao Zhang 
11466976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
115d71ae5a4SJacob Faibussowitsch {
1168c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1178c3ff71bSJunchao Zhang   PetscInt    nt;
1188c3ff71bSJunchao Zhang 
1198c3ff71bSJunchao Zhang   PetscFunctionBegin;
1209566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
12108401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1229566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1239566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1249566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1259566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1263ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1278c3ff71bSJunchao Zhang }
1288c3ff71bSJunchao Zhang 
129076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
130076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
131076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
132076ba34aSJunchao Zhang */
13366976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
134d71ae5a4SJacob Faibussowitsch {
135076ba34aSJunchao Zhang   Mat             Ad, Ao;
136076ba34aSJunchao Zhang   const PetscInt *cmap;
137076ba34aSJunchao Zhang 
138076ba34aSJunchao Zhang   PetscFunctionBegin;
1399566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1409566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
141076ba34aSJunchao Zhang   if (glob) {
142076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1439566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1449566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1459566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1469566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
147076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
148076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1499566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
150076ba34aSJunchao Zhang   }
1513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
152076ba34aSJunchao Zhang }
153076ba34aSJunchao Zhang 
1540e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
155076ba34aSJunchao Zhang struct MatMatStruct {
1560e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1570e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1580e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1590e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1600e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1610e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1620e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1630e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1640e3ece09SJunchao Zhang 
1650e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1660e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1670e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1680e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1690e3ece09SJunchao Zhang 
170aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1710e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1720e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1730e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1740e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1750e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
176076ba34aSJunchao Zhang 
177d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
178d71ae5a4SJacob Faibussowitsch   {
1793ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1803ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1813ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
182076ba34aSJunchao Zhang   }
183076ba34aSJunchao Zhang };
184076ba34aSJunchao Zhang 
185076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1860e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1870e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1880e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
189076ba34aSJunchao Zhang };
190076ba34aSJunchao Zhang 
191076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1920e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1930e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1940e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1950e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
196076ba34aSJunchao Zhang };
197076ba34aSJunchao Zhang 
1989371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1993ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
2003ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
2013ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
2020e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
203076ba34aSJunchao Zhang 
204d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
205d71ae5a4SJacob Faibussowitsch   {
206076ba34aSJunchao Zhang     delete mmAB;
207076ba34aSJunchao Zhang     delete mmAtB;
2080e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
209076ba34aSJunchao Zhang   }
210076ba34aSJunchao Zhang };
211076ba34aSJunchao Zhang 
212d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
213d71ae5a4SJacob Faibussowitsch {
214076ba34aSJunchao Zhang   PetscFunctionBegin;
2159566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
2163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
217076ba34aSJunchao Zhang }
218076ba34aSJunchao Zhang 
219076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
220076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
221076ba34aSJunchao Zhang 
222076ba34aSJunchao Zhang   Input Parameters:
223076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
224076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
225076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
226076ba34aSJunchao Zhang 
2272fe279fdSBarry Smith   Output Parameter:
228076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
229076ba34aSJunchao Zhang */
2300e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
231d71ae5a4SJacob Faibussowitsch {
232076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
233076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
234076ba34aSJunchao Zhang 
235076ba34aSJunchao Zhang   PetscFunctionBegin;
2369566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2379566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2389566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2399566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
240076ba34aSJunchao Zhang 
241aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
24208401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
2430e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
24408401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
245076ba34aSJunchao Zhang   mpiaij->A      = A;
246076ba34aSJunchao Zhang   mpiaij->B      = B;
2470e3ece09SJunchao Zhang   mpiaij->garray = garray;
248076ba34aSJunchao Zhang 
249076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
250076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
251076ba34aSJunchao Zhang 
2529566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2539566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
254076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
255076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
256076ba34aSJunchao Zhang   */
2579566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2589566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2599566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
261076ba34aSJunchao Zhang }
262076ba34aSJunchao Zhang 
2630e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
2640e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
2650e3ece09SJunchao Zhang template <class ExecutionSpace>
2660e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
267d71ae5a4SJacob Faibussowitsch {
2680e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
269076ba34aSJunchao Zhang 
270076ba34aSJunchao Zhang   PetscFunctionBegin;
2710e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
272076ba34aSJunchao Zhang 
2730e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
274076ba34aSJunchao Zhang 
2750e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
276076ba34aSJunchao Zhang 
2770e3ece09SJunchao Zhang   if (vector_length < 1) {
2780e3ece09SJunchao Zhang     vector_length = 1;
2790e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
280076ba34aSJunchao Zhang   }
281076ba34aSJunchao Zhang 
2820e3ece09SJunchao Zhang   // Determine rows per thread
2830e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2840e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
2850e3ece09SJunchao Zhang     else {
2860e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2870e3ece09SJunchao Zhang         rows_per_thread = 256;
2880e3ece09SJunchao Zhang       } else rows_per_thread = 64;
289076ba34aSJunchao Zhang     }
290076ba34aSJunchao Zhang   }
291076ba34aSJunchao Zhang 
2920e3ece09SJunchao Zhang   if (team_size < 1) {
2930e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
2940e3ece09SJunchao Zhang       team_size = 256 / vector_length;
295076ba34aSJunchao Zhang     } else {
2960e3ece09SJunchao Zhang       team_size = 1;
2970e3ece09SJunchao Zhang     }
298076ba34aSJunchao Zhang   }
299076ba34aSJunchao Zhang 
3000e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
301076ba34aSJunchao Zhang 
3020e3ece09SJunchao Zhang   if (rows_per_team < 0) {
3030e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
3040e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
3050e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
3060e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
3070e3ece09SJunchao Zhang   }
3083ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
309076ba34aSJunchao Zhang }
310076ba34aSJunchao Zhang 
3110e3ece09SJunchao Zhang /*
3120e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
313076ba34aSJunchao Zhang 
314076ba34aSJunchao Zhang   Input Parameters:
3150e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
3160e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
3170e3ece09SJunchao Zhang .  m           - size of indices[], the second set
3180e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
319076ba34aSJunchao Zhang 
320076ba34aSJunchao Zhang   Output Parameters:
3210e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
3220e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
3230e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
3240e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
325076ba34aSJunchao Zhang 
3260e3ece09SJunchao Zhang    Example, say
3270e3ece09SJunchao Zhang     n1         = 5
3280e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
3290e3ece09SJunchao Zhang     m          = 4
3300e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
33111a5261eSBarry Smith 
3320e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
3330e3ece09SJunchao Zhang     n2         = 7
3340e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
3350e3ece09SJunchao Zhang 
3360e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
3370e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
3380e3ece09SJunchao Zhang 
3390e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
3400e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
341076ba34aSJunchao Zhang */
3420e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
343d71ae5a4SJacob Faibussowitsch {
3440e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
3450e3ece09SJunchao Zhang   PetscHashIter iter;
3460e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
3470e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
348076ba34aSJunchao Zhang 
349076ba34aSJunchao Zhang   PetscFunctionBegin;
3500e3ece09SJunchao Zhang   tot = 0;
3510e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
3520e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
3530e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
3540e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
355076ba34aSJunchao Zhang   }
356076ba34aSJunchao Zhang 
3570e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
3580e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
3590e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
360076ba34aSJunchao Zhang   }
361076ba34aSJunchao Zhang 
3620e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
3630e3ece09SJunchao Zhang   n2 = tot;
3640e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
3650e3ece09SJunchao Zhang   tot = 0;
3660e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
3670e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
3680e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
3690e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
3700e3ece09SJunchao Zhang     garray2[tot++] = key;
371076ba34aSJunchao Zhang   }
372076ba34aSJunchao Zhang 
3730e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
3740e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
3750e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
3760e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
377f0e6e2d1SJunchao Zhang 
3780e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
379f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3800e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3810e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3820e3ece09SJunchao Zhang     indices[i] = val;
3830e3ece09SJunchao Zhang   }
3840e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3850e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3860e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3870e3ece09SJunchao Zhang   *n2_      = n2;
3880e3ece09SJunchao Zhang   *garray2_ = garray2;
3890e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3900e3ece09SJunchao Zhang }
391f0e6e2d1SJunchao Zhang 
3920e3ece09SJunchao Zhang /*
3930e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3940e3ece09SJunchao Zhang 
3950e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
3960e3ece09SJunchao Zhang 
3970e3ece09SJunchao Zhang   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
3980e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3990e3ece09SJunchao Zhang 
4000e3ece09SJunchao Zhang   Input Parameters:
4010e3ece09SJunchao Zhang +  comm       - MPI communicator of E
4020e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
4030e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
4040e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
4050e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
4060e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
4070e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
4080e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
4090e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
4100e3ece09SJunchao Zhang 
4110e3ece09SJunchao Zhang   Output Parameters:
4120e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
4130e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
4140e3ece09SJunchao Zhang 
4150e3ece09SJunchao Zhang   Notes:
4160e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
4170e3ece09SJunchao Zhang 
4180e3ece09SJunchao Zhang  */
4190e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
4200e3ece09SJunchao Zhang {
4210e3ece09SJunchao Zhang   PetscFunctionBegin;
4220e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
4230e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
4240e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
4250e3ece09SJunchao Zhang 
4260e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
4270e3ece09SJunchao Zhang 
4280e3ece09SJunchao Zhang     // Do the analysis on host
4290e3ece09SJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
4300e3ece09SJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
4310e3ece09SJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
4320e3ece09SJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
4330e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
4340e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
4350e3ece09SJunchao Zhang 
4360e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
4377b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
4380e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
4390e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
4400e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
4410e3ece09SJunchao Zhang       PetscInt        count, step;
4420e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
4430e3ece09SJunchao Zhang       first = Bj + Bi[i];
4440e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
445f0e6e2d1SJunchao Zhang       count = last - first;
446f0e6e2d1SJunchao Zhang       while (count > 0) {
447f0e6e2d1SJunchao Zhang         it   = first;
448f0e6e2d1SJunchao Zhang         step = count / 2;
449f0e6e2d1SJunchao Zhang         it += step;
4500e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
451f0e6e2d1SJunchao Zhang           first = ++it;
452f0e6e2d1SJunchao Zhang           count -= step + 1;
453f0e6e2d1SJunchao Zhang         } else count = step;
454f0e6e2d1SJunchao Zhang       }
4550e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
4560e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
457f0e6e2d1SJunchao Zhang     }
458f0e6e2d1SJunchao Zhang 
4590e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
4600e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
4610e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
4620e3ece09SJunchao Zhang     PetscInt           niranks, nranks;
4630e3ece09SJunchao Zhang     MPI_Request       *reqs;
4640e3ece09SJunchao Zhang     PetscMPIInt        tag;
4650e3ece09SJunchao Zhang     PetscSF            reduceSF;
4660e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
467f0e6e2d1SJunchao Zhang 
4680e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
4690e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
4700e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
471f0e6e2d1SJunchao Zhang 
4720e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
4730e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
4740e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
4750e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
476f0e6e2d1SJunchao Zhang 
4770e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4780e3ece09SJunchao Zhang 
4790e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4800e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4810e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4820e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4830e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4840e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4850e3ece09SJunchao Zhang 
4860e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4870e3ece09SJunchao Zhang     rdisp[0] = 0;
4880e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4890e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4900e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4910e3ece09SJunchao Zhang     }
4920e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4930e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4940e3ece09SJunchao Zhang 
4950e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4960e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4970e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4980e3ece09SJunchao Zhang 
4990e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
5000e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
5010e3ece09SJunchao Zhang     PetscSFNode *iremote;
5020e3ece09SJunchao Zhang 
5030e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
5040e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
5050e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
5060e3ece09SJunchao Zhang 
5070e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
5080e3ece09SJunchao Zhang       PetscInt count = 0;
5090e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
5100e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
5110e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
5120e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
5130e3ece09SJunchao Zhang       }
5140e3ece09SJunchao Zhang       nleaves += count;
5150e3ece09SJunchao Zhang     }
5160e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
5170e3ece09SJunchao Zhang 
5180e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
5190e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
5200e3ece09SJunchao Zhang 
5210e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
5220e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
5230e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
5240e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
5250e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
5260e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
5270e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
5280e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
5290e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
5300e3ece09SJunchao Zhang         if (j < nzLeft) {
5310e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
5320e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
5330e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
5340e3ece09SJunchao Zhang         } else {
5350e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
5360e3ece09SJunchao Zhang         }
5370e3ece09SJunchao Zhang       }
5380e3ece09SJunchao Zhang     }
5390e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
5400e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
5410e3ece09SJunchao Zhang 
5420e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
5430e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
5440e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
5450e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
5460e3ece09SJunchao Zhang 
5470e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
5480e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
5490e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
5500e3ece09SJunchao Zhang 
5510e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
5527b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
5530e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
5540e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
5550e3ece09SJunchao Zhang     PetscInt                iter;
5560e3ece09SJunchao Zhang 
5570e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
5580e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
5590e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
5600e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
5610e3ece09SJunchao Zhang     iter = 0;
5620e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
5630e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
5640e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
5650e3ece09SJunchao Zhang 
5660e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5670e3ece09SJunchao Zhang 
5680e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
5690e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
5700e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
5710e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
5720e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
5730e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
5740e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5750e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
5760e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5770e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5780e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5790e3ece09SJunchao Zhang         jbuf2 += len;
5800e3ece09SJunchao Zhang         pbuf2 += len;
5810e3ece09SJunchao Zhang         nz += len;
5820e3ece09SJunchao Zhang       }
5830e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5840e3ece09SJunchao Zhang 
5850e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5860e3ece09SJunchao Zhang       PetscInt cur = 0;
5870e3ece09SJunchao Zhang       while (cur < nz) {
5880e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5890e3ece09SJunchao Zhang         PetscInt dups      = 1;
5900e3ece09SJunchao Zhang 
5910e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5920e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5930e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5940e3ece09SJunchao Zhang           FdnzDups += dups;
5950e3ece09SJunchao Zhang         } else {
5960e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5970e3ece09SJunchao Zhang           FonzDups += dups;
5980e3ece09SJunchao Zhang         }
5990e3ece09SJunchao Zhang         cur += dups;
6000e3ece09SJunchao Zhang       }
6010e3ece09SJunchao Zhang 
6020e3ece09SJunchao Zhang       FnzDups += nz;
6030e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6040e3ece09SJunchao Zhang     }
6050e3ece09SJunchao Zhang 
6060e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
6070e3ece09SJunchao Zhang     Foi = Foi_h.data();
6080e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
6090e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
6100e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
6110e3ece09SJunchao Zhang     }
6120e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
6130e3ece09SJunchao Zhang     Fonz = Foi[Fm];
6140e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
6150e3ece09SJunchao Zhang 
6160e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
6177b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
6187b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
6197b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
6200e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
6210e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
6220e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
6230e3ece09SJunchao Zhang 
6240e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
6250e3ece09SJunchao Zhang     Fdjmap[0] = 0;
6260e3ece09SJunchao Zhang     Fojmap[0] = 0;
6270e3ece09SJunchao Zhang     FnzDups   = 0;
6280e3ece09SJunchao Zhang     Fdnz      = 0;
6290e3ece09SJunchao Zhang     Fonz      = 0;
6300e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
6310e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
6320e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
6330e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
6340e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
6350e3ece09SJunchao Zhang 
6360e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
6370e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
6380e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
6390e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
6400e3ece09SJunchao Zhang       }
6410e3ece09SJunchao Zhang 
6420e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
6430e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
6440e3ece09SJunchao Zhang       PetscInt cur = 0;
6450e3ece09SJunchao Zhang       while (cur < nz) {
6460e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
6470e3ece09SJunchao Zhang         PetscInt dups      = 1;
6480e3ece09SJunchao Zhang 
6490e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
6500e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
6510e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
6520e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
6530e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
6540e3ece09SJunchao Zhang           FdnzDups += dups;
6550e3ece09SJunchao Zhang           Fdnz++;
6560e3ece09SJunchao Zhang         } else {
6570e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
6580e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
6590e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
6600e3ece09SJunchao Zhang           FonzDups += dups;
6610e3ece09SJunchao Zhang           Fonz++;
6620e3ece09SJunchao Zhang         }
6630e3ece09SJunchao Zhang         cur += dups;
6640e3ece09SJunchao Zhang         FnzDups += dups;
6650e3ece09SJunchao Zhang       }
6660e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6670e3ece09SJunchao Zhang     }
6680e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
6690e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
6700e3ece09SJunchao Zhang 
6710e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
6720e3ece09SJunchao Zhang     PetscInt n2, *garray2;
6730e3ece09SJunchao Zhang 
6740e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
6750e3ece09SJunchao Zhang     mm->sf       = reduceSF;
6767b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
6777b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
678aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6790e3ece09SJunchao Zhang     mm->n        = n2;
6800e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6810e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6820e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6830e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6840e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6850e3ece09SJunchao Zhang 
6860e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
6877b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6880e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6890e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
6907b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6910e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6920e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6930e3ece09SJunchao Zhang 
6940e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6950e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6960e3ece09SJunchao Zhang 
6970e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6980e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6990e3ece09SJunchao Zhang 
7000e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
7010e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
7020e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
7030e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
7040e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
7050e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
7060e3ece09SJunchao Zhang 
7070e3ece09SJunchao Zhang   // Handy aliases
7080e3ece09SJunchao Zhang   auto       &Aa           = A.values;
7090e3ece09SJunchao Zhang   auto       &Ba           = B.values;
7100e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
7110e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
7120e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
7130e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
7140e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
7150e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
7160e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
7170e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
7180e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
7190e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
7200e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
7210e3ece09SJunchao Zhang 
7220e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
7230e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
724d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
7250e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
7260e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
7270e3ece09SJunchao Zhang         if (i < Em) {
7280e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
7290e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
7300e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
7310e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
7320e3ece09SJunchao Zhang 
7330e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
7340e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
7350e3ece09SJunchao Zhang             if (j < nzleft) { // B left
7360e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
7370e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
7380e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
7390e3ece09SJunchao Zhang             } else { // B right
7400e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
741f0e6e2d1SJunchao Zhang             }
742f0e6e2d1SJunchao Zhang           });
743f0e6e2d1SJunchao Zhang         }
744f0e6e2d1SJunchao Zhang       });
7450e3ece09SJunchao Zhang     }));
7460e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
747f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
748f0e6e2d1SJunchao Zhang }
7490e3ece09SJunchao Zhang 
750aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
7510e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
7520e3ece09SJunchao Zhang {
7530e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
7540e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
7550e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
7560e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
7570e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
7580e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
7590e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
7600e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
7610e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
7620e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
7630e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
7640e3ece09SJunchao Zhang 
765d326c3f1SJunchao Zhang   PetscFunctionBegin;
7660e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
7670e3ece09SJunchao Zhang 
7680e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
7690e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
770d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
7710e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7720e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
7730e3ece09SJunchao Zhang       Fda(i) = sum;
7740e3ece09SJunchao Zhang     }));
7750e3ece09SJunchao Zhang 
7760e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
777d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
7780e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7790e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7800e3ece09SJunchao Zhang       Foa(i) = sum;
7810e3ece09SJunchao Zhang     }));
7820e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7830e3ece09SJunchao Zhang }
7840e3ece09SJunchao Zhang 
7850e3ece09SJunchao Zhang /*
7860e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7870e3ece09SJunchao Zhang 
7880e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
7890e3ece09SJunchao Zhang   device and involves various index mapping.
7900e3ece09SJunchao Zhang 
7910e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
7920e3ece09SJunchao Zhang   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
7930e3ece09SJunchao Zhang   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7940e3ece09SJunchao Zhang   F has the same column layout as E.
7950e3ece09SJunchao Zhang 
  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
  Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
  Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
  column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
  column indices in Fo and update Fo with local indices.
8010e3ece09SJunchao Zhang 
8020e3ece09SJunchao Zhang    Input Parameters:
8030e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
8049c89aa79SPierre Jolivet .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
8050e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
8060e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
8070e3ece09SJunchao Zhang 
8080e3ece09SJunchao Zhang     Output Parameters:
8090e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
8100e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
8110e3ece09SJunchao Zhang 
    Notes:
    When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
    The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide opportunities for overlapping computation with communication.
8150e3ece09SJunchao Zhang */
8160e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
8170e3ece09SJunchao Zhang {
8180e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
8190e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
8200e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8210e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
8220e3ece09SJunchao Zhang   MPI_Comm          comm;
8230e3ece09SJunchao Zhang 
8240e3ece09SJunchao Zhang   PetscFunctionBegin;
8250e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
8260e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
8270e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
8280e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
8290e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
8300e3ece09SJunchao Zhang     PetscInt        cstart, cend;
8310e3ece09SJunchao Zhang     PetscSF         bcastSF;
8320e3ece09SJunchao Zhang 
8330e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
8340e3ece09SJunchao Zhang 
8350e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
8367b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
8370e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
8380e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
8390e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
8400e3ece09SJunchao Zhang       PetscInt        count, step;
8410e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
8420e3ece09SJunchao Zhang       first = Bj + Bi[i];
8430e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
8440e3ece09SJunchao Zhang       count = last - first;
8450e3ece09SJunchao Zhang       while (count > 0) {
8460e3ece09SJunchao Zhang         it   = first;
8470e3ece09SJunchao Zhang         step = count / 2;
8480e3ece09SJunchao Zhang         it += step;
8490e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
8500e3ece09SJunchao Zhang           first = ++it;
8510e3ece09SJunchao Zhang           count -= step + 1;
8520e3ece09SJunchao Zhang         } else count = step;
8530e3ece09SJunchao Zhang       }
8540e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
8550e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
8560e3ece09SJunchao Zhang     }
8570e3ece09SJunchao Zhang 
8580e3ece09SJunchao Zhang     // Compute row pointer Fi of F
8590e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
8600e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
8610e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
8620e3ece09SJunchao Zhang     Fi[0] = 0;
8630e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
8640e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
8650e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
8660e3ece09SJunchao Zhang     Fnz = Fi[Fm];
8670e3ece09SJunchao Zhang 
8680e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
8690e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
8700e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
8710e3ece09SJunchao Zhang     PetscInt           niranks, nranks, *sdisp, *rdisp;
8720e3ece09SJunchao Zhang     MPI_Request       *reqs;
8730e3ece09SJunchao Zhang     PetscMPIInt        tag;
8740e3ece09SJunchao Zhang 
8750e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8760e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8770e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8780e3ece09SJunchao Zhang 
8790e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8800e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8810e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8820e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8830e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8840e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8850e3ece09SJunchao Zhang       }
8860e3ece09SJunchao Zhang     }
8870e3ece09SJunchao Zhang 
8880e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8890e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8900e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8910e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8920e3ece09SJunchao Zhang 
8930e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8940e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8950e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8960e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8970e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8980e3ece09SJunchao Zhang       PetscInt k = 0;
8990e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
9000e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
9010e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
9020e3ece09SJunchao Zhang         k++;
9030e3ece09SJunchao Zhang       }
9040e3ece09SJunchao Zhang     }
9050e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
9060e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
9070e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
9080e3ece09SJunchao Zhang 
9090e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
9107b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
9110e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
9120e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
9137b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
9140e3ece09SJunchao Zhang 
9150e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
9160e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
9170e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
9180e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
9190e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
9200e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
9210e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
9220e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
9230e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
9240e3ece09SJunchao Zhang         if (j < nzLeft) {
9250e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
9260e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
9270e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
9280e3ece09SJunchao Zhang         } else {
9290e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
9300e3ece09SJunchao Zhang         }
9310e3ece09SJunchao Zhang       }
9320e3ece09SJunchao Zhang     }
9330e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
9340e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
9350e3ece09SJunchao Zhang 
9360e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
9377b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
9387b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
9390e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
9400e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
9410e3ece09SJunchao Zhang 
9420e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
9430e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9440e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
9450e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
9460e3ece09SJunchao Zhang       first       = Fj + Fi[i];
9470e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
9480e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
9490e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
9500e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
9510e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
9520e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
9530e3ece09SJunchao Zhang     }
9540e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9550e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
9560e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
9570e3ece09SJunchao Zhang     }
9580e3ece09SJunchao Zhang 
9590e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
9600e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
9617b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
9620e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
9630e3ece09SJunchao Zhang 
9640e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9650e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
9660e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
9670e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
9680e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
9690e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
9700e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
9710e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
9720e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
9730e3ece09SJunchao Zhang         } else { // right, in global
9740e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
9750e3ece09SJunchao Zhang         }
9760e3ece09SJunchao Zhang       }
9770e3ece09SJunchao Zhang     }
9780e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9790e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9800e3ece09SJunchao Zhang 
9810e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9820e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9830e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9840e3ece09SJunchao Zhang 
9850e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9860e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
9877b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9880e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9890e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9900e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9910e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9920e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9930e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
9947b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
9957b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9960e3ece09SJunchao Zhang     mm->garray    = garray2;
9970e3ece09SJunchao Zhang     mm->n         = n2;
9980e3ece09SJunchao Zhang 
9990e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
10007b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
10010e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
10020e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
10030e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
10040e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
10050e3ece09SJunchao Zhang 
10060e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
10070e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
10080e3ece09SJunchao Zhang 
10090e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
10100e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
10110e3ece09SJunchao Zhang 
10120e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10130e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
10140e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
10150e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
10160e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
10170e3ece09SJunchao Zhang 
10180e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10190e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
10200e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
10210e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
10220e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
10230e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
10240e3ece09SJunchao Zhang 
10250e3ece09SJunchao Zhang   // Sync E's value to device
10260e3ece09SJunchao Zhang   akok->a_dual.sync_device();
10270e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
10280e3ece09SJunchao Zhang 
10290e3ece09SJunchao Zhang   // Handy aliases
10300e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
10310e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
10320e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
10330e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
10340e3ece09SJunchao Zhang 
10350e3ece09SJunchao Zhang   // Fetch the plans
10360e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
10370e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
10380e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
10390e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
10400e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
10410e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
10420e3ece09SJunchao Zhang 
10430e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
10440e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
10450e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
10460e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
10470e3ece09SJunchao Zhang 
10480e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
10490e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1050d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10510e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10520e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
10530e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
10540e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
10550e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
10560e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
10570e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
10580e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
10590e3ece09SJunchao Zhang 
10600e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10610e3ece09SJunchao Zhang             if (j < nzleft) { // B left
10620e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
10630e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
10640e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
10650e3ece09SJunchao Zhang             } else { // B right
10660e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
10670e3ece09SJunchao Zhang             }
10680e3ece09SJunchao Zhang           });
10690e3ece09SJunchao Zhang         }
10700e3ece09SJunchao Zhang       });
10710e3ece09SJunchao Zhang     }));
10720e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
10730e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10740e3ece09SJunchao Zhang }
10750e3ece09SJunchao Zhang 
10760e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10770e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10780e3ece09SJunchao Zhang {
10790e3ece09SJunchao Zhang   PetscFunctionBegin;
10800e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10810e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10820e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10830e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10840e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10850e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10860e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10870e3ece09SJunchao Zhang 
10880e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10890e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10900e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10910e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10920e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10930e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10940e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10950e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10960e3ece09SJunchao Zhang 
10970e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10980e3ece09SJunchao Zhang 
10990e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
11000e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1101d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
11020e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
11030e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
11040e3ece09SJunchao Zhang         if (i < Fm) {
11050e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
11060e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
11070e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
11080e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
11090e3ece09SJunchao Zhang 
11100e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
11110e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
11120e3ece09SJunchao Zhang             if (j < nzLeft) { // left
11130e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
11140e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
11150e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
11160e3ece09SJunchao Zhang             } else { // right
11170e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
11180e3ece09SJunchao Zhang             }
11190e3ece09SJunchao Zhang           });
11200e3ece09SJunchao Zhang         }
11210e3ece09SJunchao Zhang       });
11220e3ece09SJunchao Zhang     }));
11230e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11240e3ece09SJunchao Zhang }
11250e3ece09SJunchao Zhang 
/* MatProductSymbolic_MPIAIJKokkos_AtB - symbolic phase of the AtB product, using transposes of A's blocks

  Input Parameters:
+  product  - Mat_Product; not referenced in this routine's body
.  A        - matrix whose data is cast to Mat_MPIAIJ; transposes of its diag (ampi->A) and off-diag (ampi->B) blocks are generated
.  B        - matrix whose data is cast to Mat_MPIAIJ; its diag (bmpi->A) and off-diag (bmpi->B) blocks are the right operands
-  mm       - struct that receives the handles kh1..kh4, the intermediates C1, C2_mid, C2, C3, C4 and the results Cd, Co

  Notes:
  Order of operations as coded: C3 = Aot*Bd and C4 = Aot*Bo are computed first and handed to
  MatMPIAIJKokkosReduceBegin(); C1 = Adt*Bd and C2_mid = Adt*Bo are computed before MatMPIAIJKokkosReduceEnd()
  is called; C2 is then built from C2_mid with remapped column indices; finally Cd = C1 + Fd and Co = C2 + Fo.
*/
11260e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11270e3ece09SJunchao Zhang {
11280e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11290e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11300e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
11310e3ece09SJunchao Zhang   PetscInt        cstart, cend;
11320e3ece09SJunchao Zhang   MPI_Comm        comm;
11330e3ece09SJunchao Zhang 
11340e3ece09SJunchao Zhang   PetscFunctionBegin;
11350e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  // Adt / Aot: transposes of A's diag / off-diag blocks
11360e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11370e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
  // NOTE(review): Ad and Ao are fetched here but not referenced again in this function;
  // confirm whether these two calls are kept for a side effect before removing them.
11380e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
11390e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
11400e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11410e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11420e3ece09SJunchao Zhang 
11430e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11440e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11450e3ece09SJunchao Zhang 
11460e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11470e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11480e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11490e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1150f0e6e2d1SJunchao Zhang   #endif
11510e3ece09SJunchao Zhang #endif
11520e3ece09SJunchao Zhang 
  // One spgemm handle per product computed below: kh1 -> C1, kh2 -> C2_mid, kh3 -> C3, kh4 -> C4
11530e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
11540e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
11550e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
11560e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
11570e3ece09SJunchao Zhang 
11580e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11590e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
11600e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
11610e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
11620e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
11630e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11640e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
  // Explicit sort only for Kokkos Kernels older than 4.0.0; presumably newer versions already return sorted rows -- TODO confirm
11650e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1166d326c3f1SJunchao Zhang 
11670e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
11680e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
11690e3ece09SJunchao Zhang #endif
11700e3ece09SJunchao Zhang 
11710e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  // map_h has bmpi->B->cmap->n entries; it is handed to the reduce routines and used further down to remap C2_mid's column indices
11727b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11730e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
11740e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11750e3ece09SJunchao Zhang 
11760e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11770e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11780e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11790e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11800e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11810e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11820e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11830e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11840e3ece09SJunchao Zhang #endif
11850e3ece09SJunchao Zhang 
  // Finish the reduce begun above; mm->n, mm->Fd and mm->Fo are read after this point
11860e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11870e3ece09SJunchao Zhang 
11880e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
11897b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11900e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  // newj(i) = map(oldj(i)): translate every column index of C2_mid through map
1191d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11920e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11930e3ece09SJunchao Zhang 
11940e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
  // kh1 and kh2 additionally get spadd handles here; they are used for the symbolic and numeric additions below
11950e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11960e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11970e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11980e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11990e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
12000e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
12010e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12020e3ece09SJunchao Zhang }
12030e3ece09SJunchao Zhang 
12040e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
12050e3ece09SJunchao Zhang {
12060e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12070e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12080e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
12090e3ece09SJunchao Zhang   MPI_Comm        comm;
12100e3ece09SJunchao Zhang 
12110e3ece09SJunchao Zhang   PetscFunctionBegin;
12120e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
12130e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
12140e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
12150e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12160e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12170e3ece09SJunchao Zhang 
12180e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
12190e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
12200e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
12210e3ece09SJunchao Zhang 
12220e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
12230e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12240e3ece09SJunchao Zhang 
12250e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
12260e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
12270e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
12280e3ece09SJunchao Zhang 
12290e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12300e3ece09SJunchao Zhang 
12310e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
12320e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
12330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
12340e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12350e3ece09SJunchao Zhang }
1236f0e6e2d1SJunchao Zhang 
1237076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1238076ba34aSJunchao Zhang 
1239076ba34aSJunchao Zhang   Input Parameters:
1240076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1241076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1242076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1243076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1244076ba34aSJunchao Zhang */
1245d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1246d71ae5a4SJacob Faibussowitsch {
12470e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12480e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12490e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1250076ba34aSJunchao Zhang 
1251076ba34aSJunchao Zhang   PetscFunctionBegin;
12520e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12530e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12540e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12550e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12560e3ece09SJunchao Zhang 
12570e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
12580e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
12590e3ece09SJunchao Zhang 
12600e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
12610e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
12620e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
12630e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
12640e3ece09SJunchao Zhang   #endif
1265f0e6e2d1SJunchao Zhang #endif
1266f0e6e2d1SJunchao Zhang 
12670e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
12680e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
12690e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
12700e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1271076ba34aSJunchao Zhang 
12720e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12737b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
12740e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1275076ba34aSJunchao Zhang 
12760e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12770e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12780e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12790e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12800e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12810e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12820e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12830e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12840e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12850e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12860e3ece09SJunchao Zhang #endif
1287076ba34aSJunchao Zhang 
12880e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1289076ba34aSJunchao Zhang 
12900e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12910e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12920e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12930e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12940e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12950e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12960e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12970e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12980e3ece09SJunchao Zhang #endif
1299076ba34aSJunchao Zhang 
13000e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
13017b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
13020e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1303d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
13040e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
13050e3ece09SJunchao Zhang 
13060e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13070e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
13080e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
13090e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
13100e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
13110e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13120e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1314076ba34aSJunchao Zhang }
1315076ba34aSJunchao Zhang 
13160e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1317d71ae5a4SJacob Faibussowitsch {
13180e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
13190e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
13200e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1321076ba34aSJunchao Zhang 
1322076ba34aSJunchao Zhang   PetscFunctionBegin;
13230e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
13240e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
13250e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
13260e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1327076ba34aSJunchao Zhang 
13280e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
13290e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1330076ba34aSJunchao Zhang 
13310e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
13320e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
13330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1334076ba34aSJunchao Zhang 
13350e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1336076ba34aSJunchao Zhang 
13370e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
13380e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
13390e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
13400e3ece09SJunchao Zhang 
13410e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13420e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13430e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1345076ba34aSJunchao Zhang }
1346076ba34aSJunchao Zhang 
134766976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1348d71ae5a4SJacob Faibussowitsch {
13490e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
13500e3ece09SJunchao Zhang   Mat_Product                 *product;
13510e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1352076ba34aSJunchao Zhang   MatProductType               ptype;
13530e3ece09SJunchao Zhang   Mat                          A, B;
1354076ba34aSJunchao Zhang 
1355076ba34aSJunchao Zhang   PetscFunctionBegin;
13560e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
13570e3ece09SJunchao Zhang   product = C->product;
13580e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1359076ba34aSJunchao Zhang   ptype   = product->type;
1360076ba34aSJunchao Zhang   A       = product->A;
1361076ba34aSJunchao Zhang   B       = product->B;
1362076ba34aSJunchao Zhang 
13630e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
13640e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
13650e3ece09SJunchao Zhang   // we still do numeric.
13660e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
13670e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13683ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1369076ba34aSJunchao Zhang   }
1370076ba34aSJunchao Zhang 
1371076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13720e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1373076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13740e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
13750e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13760e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13770e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1378076ba34aSJunchao Zhang   }
13790e3ece09SJunchao Zhang 
13800e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13810e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1383076ba34aSJunchao Zhang }
1384076ba34aSJunchao Zhang 
138566976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1386d71ae5a4SJacob Faibussowitsch {
1387076ba34aSJunchao Zhang   Mat                          A, B;
13880e3ece09SJunchao Zhang   Mat_Product                 *product;
1389076ba34aSJunchao Zhang   MatProductType               ptype;
13900e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1391076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13920e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13930e3ece09SJunchao Zhang   Mat                          Cd, Co;
13940e3ece09SJunchao Zhang   MPI_Comm                     comm;
1395076ba34aSJunchao Zhang 
1396076ba34aSJunchao Zhang   PetscFunctionBegin;
13970e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1398076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13990e3ece09SJunchao Zhang   product = C->product;
14000e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1401076ba34aSJunchao Zhang   ptype = product->type;
1402076ba34aSJunchao Zhang   A     = product->A;
1403076ba34aSJunchao Zhang   B     = product->B;
1404076ba34aSJunchao Zhang 
1405076ba34aSJunchao Zhang   switch (ptype) {
14069371c9d4SSatish Balay   case MATPRODUCT_AB:
14079371c9d4SSatish Balay     m = A->rmap->n;
14089371c9d4SSatish Balay     n = B->cmap->n;
14099371c9d4SSatish Balay     M = A->rmap->N;
14109371c9d4SSatish Balay     N = B->cmap->N;
14119371c9d4SSatish Balay     break;
14129371c9d4SSatish Balay   case MATPRODUCT_AtB:
14139371c9d4SSatish Balay     m = A->cmap->n;
14149371c9d4SSatish Balay     n = B->cmap->n;
14159371c9d4SSatish Balay     M = A->cmap->N;
14169371c9d4SSatish Balay     N = B->cmap->N;
14179371c9d4SSatish Balay     break;
14189371c9d4SSatish Balay   case MATPRODUCT_PtAP:
14199371c9d4SSatish Balay     m = B->cmap->n;
14209371c9d4SSatish Balay     n = B->cmap->n;
14219371c9d4SSatish Balay     M = B->cmap->N;
14229371c9d4SSatish Balay     N = B->cmap->N;
14239371c9d4SSatish Balay     break; /* BtAB */
1424d71ae5a4SJacob Faibussowitsch   default:
14250e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1426076ba34aSJunchao Zhang   }
1427076ba34aSJunchao Zhang 
14289566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
14299566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
14309566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
14319566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1432076ba34aSJunchao Zhang 
14330e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
14340e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1435076ba34aSJunchao Zhang 
1436076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
14370e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14380e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
14390e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1440076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
14410e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14420e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
14430e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
14440e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
14450e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
14460e3ece09SJunchao Zhang 
14470e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14480e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
14490e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
14500e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
14510e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
14520e3ece09SJunchao Zhang 
14530e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
14540e3ece09SJunchao Zhang     n = B->cmap->n;
14550e3ece09SJunchao Zhang     M = A->rmap->N;
14560e3ece09SJunchao Zhang     N = B->cmap->N;
14570e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
14580e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
14590e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
14600e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
14610e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
14620e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
14630e3ece09SJunchao Zhang 
14640e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14650e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
14660e3ece09SJunchao Zhang 
14670e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
14680e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1469076ba34aSJunchao Zhang   }
14700e3ece09SJunchao Zhang 
14710e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
14720e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
14730e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
14740e3ece09SJunchao Zhang 
14750e3ece09SJunchao Zhang   C->product->data       = pdata;
1476076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1477076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1479076ba34aSJunchao Zhang }
1480076ba34aSJunchao Zhang 
1481d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1482d71ae5a4SJacob Faibussowitsch {
1483076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1484076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1485076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1486076ba34aSJunchao Zhang 
1487076ba34aSJunchao Zhang   PetscFunctionBegin;
1488076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
148948a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1490076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1491076ba34aSJunchao Zhang     switch (product->type) {
1492076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1493076ba34aSJunchao Zhang       if (product->api_user) {
1494d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14959566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1496d0609cedSBarry Smith         PetscOptionsEnd();
1497076ba34aSJunchao Zhang       } else {
1498d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14999566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1500d0609cedSBarry Smith         PetscOptionsEnd();
1501076ba34aSJunchao Zhang       }
1502076ba34aSJunchao Zhang       break;
1503076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1504076ba34aSJunchao Zhang       if (product->api_user) {
1505d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
15069566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1507d0609cedSBarry Smith         PetscOptionsEnd();
1508076ba34aSJunchao Zhang       } else {
1509d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
15109566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1511d0609cedSBarry Smith         PetscOptionsEnd();
1512076ba34aSJunchao Zhang       }
1513076ba34aSJunchao Zhang       break;
1514076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1515076ba34aSJunchao Zhang       if (product->api_user) {
1516d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
15179566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1518d0609cedSBarry Smith         PetscOptionsEnd();
1519076ba34aSJunchao Zhang       } else {
1520d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
15219566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1522d0609cedSBarry Smith         PetscOptionsEnd();
1523076ba34aSJunchao Zhang       }
1524076ba34aSJunchao Zhang       break;
1525d71ae5a4SJacob Faibussowitsch     default:
1526d71ae5a4SJacob Faibussowitsch       break;
1527076ba34aSJunchao Zhang     }
1528076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1529076ba34aSJunchao Zhang   }
1530076ba34aSJunchao Zhang   if (match) {
1531076ba34aSJunchao Zhang     switch (product->type) {
1532076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1533076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1534d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1535d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1536d71ae5a4SJacob Faibussowitsch       break;
1537d71ae5a4SJacob Faibussowitsch     default:
1538d71ae5a4SJacob Faibussowitsch       break;
1539076ba34aSJunchao Zhang     }
1540076ba34aSJunchao Zhang   }
1541076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
154248a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
15433ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1544076ba34aSJunchao Zhang }
1545076ba34aSJunchao Zhang 
15462c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device
15472c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos {
15482c4ab24aSJunchao Zhang   PetscCount           n;
15492c4ab24aSJunchao Zhang   PetscSF              sf;
15502c4ab24aSJunchao Zhang   PetscCount           Annz, Bnnz;
15512c4ab24aSJunchao Zhang   PetscCount           Annz2, Bnnz2;
15522c4ab24aSJunchao Zhang   PetscCountKokkosView Ajmap1, Aperm1;
15532c4ab24aSJunchao Zhang   PetscCountKokkosView Bjmap1, Bperm1;
15542c4ab24aSJunchao Zhang   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
15552c4ab24aSJunchao Zhang   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
15562c4ab24aSJunchao Zhang   PetscCountKokkosView Cperm1;
15572c4ab24aSJunchao Zhang   MatScalarKokkosView  sendbuf, recvbuf;
15582c4ab24aSJunchao Zhang 
1559*92896123SJunchao Zhang   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
15602c4ab24aSJunchao Zhang   {
1561*92896123SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
1562*92896123SJunchao Zhang 
1563*92896123SJunchao Zhang     n       = coo_h->n;
1564*92896123SJunchao Zhang     sf      = coo_h->sf;
1565*92896123SJunchao Zhang     Annz    = coo_h->Annz;
1566*92896123SJunchao Zhang     Bnnz    = coo_h->Bnnz;
1567*92896123SJunchao Zhang     Annz2   = coo_h->Annz2;
1568*92896123SJunchao Zhang     Bnnz2   = coo_h->Bnnz2;
1569*92896123SJunchao Zhang     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1570*92896123SJunchao Zhang     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1571*92896123SJunchao Zhang     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1572*92896123SJunchao Zhang     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1573*92896123SJunchao Zhang     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1574*92896123SJunchao Zhang     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1575*92896123SJunchao Zhang     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1576*92896123SJunchao Zhang     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1577*92896123SJunchao Zhang     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1578*92896123SJunchao Zhang     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1579*92896123SJunchao Zhang     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1580*92896123SJunchao Zhang     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1581*92896123SJunchao Zhang     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
15822c4ab24aSJunchao Zhang     PetscCallVoid(PetscObjectReference((PetscObject)sf));
15832c4ab24aSJunchao Zhang   }
15842c4ab24aSJunchao Zhang 
15852c4ab24aSJunchao Zhang   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
15862c4ab24aSJunchao Zhang };
15872c4ab24aSJunchao Zhang 
15882c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
15892c4ab24aSJunchao Zhang {
15902c4ab24aSJunchao Zhang   PetscFunctionBegin;
15912c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
15922c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15932c4ab24aSJunchao Zhang }
15942c4ab24aSJunchao Zhang 
1595d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1596d71ae5a4SJacob Faibussowitsch {
15972c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
15982c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJ       *coo_h;
15992c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo_d;
160042550becSJunchao Zhang 
160142550becSJunchao Zhang   PetscFunctionBegin;
160230203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1603cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
16049566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
16059566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
16069566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
16072c4ab24aSJunchao Zhang 
16082c4ab24aSJunchao Zhang   // Copy the COO struct to device
16092c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
16102c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
16112c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
16122c4ab24aSJunchao Zhang 
16132c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
16142c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
16152c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
16162c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
16172c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
16182c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
16193ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
162042550becSJunchao Zhang }
162142550becSJunchao Zhang 
1622d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1623d71ae5a4SJacob Faibussowitsch {
1624394ed5ebSJunchao Zhang   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
162542550becSJunchao Zhang   Mat                            A = mpiaij->A, B = mpiaij->B;
162642550becSJunchao Zhang   MatScalarKokkosView            Aa, Ba;
1627394ed5ebSJunchao Zhang   MatScalarKokkosView            v1;
162842550becSJunchao Zhang   PetscMemType                   memtype;
16292c4ab24aSJunchao Zhang   PetscContainer                 container;
16302c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos     *coo;
1631*92896123SJunchao Zhang   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
163242550becSJunchao Zhang 
163342550becSJunchao Zhang   PetscFunctionBegin;
16342c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
16352c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
16362c4ab24aSJunchao Zhang 
16372c4ab24aSJunchao Zhang   const auto &n      = coo->n;
16382c4ab24aSJunchao Zhang   const auto &Annz   = coo->Annz;
16392c4ab24aSJunchao Zhang   const auto &Annz2  = coo->Annz2;
16402c4ab24aSJunchao Zhang   const auto &Bnnz   = coo->Bnnz;
16412c4ab24aSJunchao Zhang   const auto &Bnnz2  = coo->Bnnz2;
16422c4ab24aSJunchao Zhang   const auto &vsend  = coo->sendbuf;
16432c4ab24aSJunchao Zhang   const auto &v2     = coo->recvbuf;
16442c4ab24aSJunchao Zhang   const auto &Ajmap1 = coo->Ajmap1;
16452c4ab24aSJunchao Zhang   const auto &Ajmap2 = coo->Ajmap2;
16462c4ab24aSJunchao Zhang   const auto &Aimap2 = coo->Aimap2;
16472c4ab24aSJunchao Zhang   const auto &Bjmap1 = coo->Bjmap1;
16482c4ab24aSJunchao Zhang   const auto &Bjmap2 = coo->Bjmap2;
16492c4ab24aSJunchao Zhang   const auto &Bimap2 = coo->Bimap2;
16502c4ab24aSJunchao Zhang   const auto &Aperm1 = coo->Aperm1;
16512c4ab24aSJunchao Zhang   const auto &Aperm2 = coo->Aperm2;
16522c4ab24aSJunchao Zhang   const auto &Bperm1 = coo->Bperm1;
16532c4ab24aSJunchao Zhang   const auto &Bperm2 = coo->Bperm2;
16542c4ab24aSJunchao Zhang   const auto &Cperm1 = coo->Cperm1;
16552c4ab24aSJunchao Zhang 
16569566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
165742550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1658*92896123SJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
165942550becSJunchao Zhang   } else {
16602c4ab24aSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
166142550becSJunchao Zhang   }
166242550becSJunchao Zhang 
166342550becSJunchao Zhang   if (imode == INSERT_VALUES) {
16649566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
16659566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1666394ed5ebSJunchao Zhang   } else {
16679566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
16689566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
166942550becSJunchao Zhang   }
167042550becSJunchao Zhang 
167108bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
167242550becSJunchao Zhang   /* Pack entries to be sent to remote */
1673*92896123SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
167442550becSJunchao Zhang 
167542550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
16762c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1677158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
16789371c9d4SSatish Balay   Kokkos::parallel_for(
1679*92896123SJunchao Zhang     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1680158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1681158ec288SJunchao Zhang       if (i < Annz) {
1682158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1683ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1684158ec288SJunchao Zhang       } else {
1685158ec288SJunchao Zhang         i -= Annz;
1686158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1687ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1688158ec288SJunchao Zhang       }
1689158ec288SJunchao Zhang     });
16902c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
169142550becSJunchao Zhang 
1692158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16939371c9d4SSatish Balay   Kokkos::parallel_for(
1694*92896123SJunchao Zhang     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1695158ec288SJunchao Zhang       if (i < Annz2) {
1696158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1697158ec288SJunchao Zhang       } else {
1698158ec288SJunchao Zhang         i -= Annz2;
1699158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1700158ec288SJunchao Zhang       }
1701158ec288SJunchao Zhang     });
170208bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
170342550becSJunchao Zhang 
1704394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
17059566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
17069566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1707394ed5ebSJunchao Zhang   } else {
17089566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
17099566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1710394ed5ebSJunchao Zhang   }
17113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
171242550becSJunchao Zhang }
171342550becSJunchao Zhang 
17142c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1715d71ae5a4SJacob Faibussowitsch {
1716076ba34aSJunchao Zhang   PetscFunctionBegin;
17179566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
17189566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
17199566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
17209566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
17219566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
17223ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1723076ba34aSJunchao Zhang }
1724076ba34aSJunchao Zhang 
1725f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1726f4747e26SJunchao Zhang {
1727f4747e26SJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1728f4747e26SJunchao Zhang   PetscBool   congruent;
1729f4747e26SJunchao Zhang 
1730f4747e26SJunchao Zhang   PetscFunctionBegin;
1731f4747e26SJunchao Zhang   PetscCall(MatHasCongruentLayouts(A, &congruent));
1732f4747e26SJunchao Zhang   if (congruent) { // square matrix and the diagonals are solely in the diag block
1733f4747e26SJunchao Zhang     PetscCall(MatShift(mpiaij->A, a));
1734f4747e26SJunchao Zhang   } else { // too hard, use the general version
1735f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1736f4747e26SJunchao Zhang   }
1737f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1738f4747e26SJunchao Zhang }
1739f4747e26SJunchao Zhang 
17402c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
17412c4ab24aSJunchao Zhang {
17422c4ab24aSJunchao Zhang   PetscFunctionBegin;
17432c4ab24aSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
17442c4ab24aSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
17452c4ab24aSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
17462c4ab24aSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
17472c4ab24aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
17482c4ab24aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1749f4747e26SJunchao Zhang   B->ops->shift                 = MatShift_MPIAIJKokkos;
17502c4ab24aSJunchao Zhang 
17512c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
17522c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
17532c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
17542c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
17552c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
17562c4ab24aSJunchao Zhang }
17572c4ab24aSJunchao Zhang 
1758d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1759d71ae5a4SJacob Faibussowitsch {
17608c3ff71bSJunchao Zhang   Mat         B;
1761076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
17628c3ff71bSJunchao Zhang 
17638c3ff71bSJunchao Zhang   PetscFunctionBegin;
17648c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17659566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
17668c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17679566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
17688c3ff71bSJunchao Zhang   }
17698c3ff71bSJunchao Zhang   B = *newmat;
17708c3ff71bSJunchao Zhang 
17716f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17729566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
17739566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
17749566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
17758c3ff71bSJunchao Zhang 
1776076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
17779566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
17789566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
17799566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
17802c4ab24aSJunchao Zhang   PetscCall(MatSetOps_MPIAIJKokkos(B));
17813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17828c3ff71bSJunchao Zhang }
17832c4ab24aSJunchao Zhang 
17843f3ba80aSJunchao Zhang /*MC
178511a5261eSBarry Smith    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
17868c3ff71bSJunchao Zhang 
178715229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
17883f3ba80aSJunchao Zhang 
17892ef1f0ffSBarry Smith    Options Database Key:
17902ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
17913f3ba80aSJunchao Zhang 
17923f3ba80aSJunchao Zhang   Level: beginner
17933f3ba80aSJunchao Zhang 
17941cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
17953f3ba80aSJunchao Zhang M*/
1796d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1797d71ae5a4SJacob Faibussowitsch {
17988c3ff71bSJunchao Zhang   PetscFunctionBegin;
17999566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
18009566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
18019566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
18023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18038c3ff71bSJunchao Zhang }
18048c3ff71bSJunchao Zhang 
18058c3ff71bSJunchao Zhang /*@C
180611a5261eSBarry Smith   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
18078c3ff71bSJunchao Zhang   (the default parallel PETSc format).  This matrix will ultimately pushed down
180820f4b53cSBarry Smith   to Kokkos for calculations.
18098c3ff71bSJunchao Zhang 
18108c3ff71bSJunchao Zhang   Collective
18118c3ff71bSJunchao Zhang 
18128c3ff71bSJunchao Zhang   Input Parameters:
181311a5261eSBarry Smith + comm  - MPI communicator, set to `PETSC_COMM_SELF`
181420f4b53cSBarry Smith . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
181520f4b53cSBarry Smith            This value should be the same as the local size used in creating the
181620f4b53cSBarry Smith            y vector for the matrix-vector product y = Ax.
181720f4b53cSBarry Smith . n     - This value should be the same as the local size used in creating the
181820f4b53cSBarry Smith        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
181920f4b53cSBarry Smith        calculated if N is given) For square matrices n is almost always `m`.
182020f4b53cSBarry Smith . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
182120f4b53cSBarry Smith . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
182220f4b53cSBarry Smith . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
182320f4b53cSBarry Smith            (same value is used for all local rows)
182420f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the
182520f4b53cSBarry Smith            DIAGONAL portion of the local submatrix (possibly different for each row)
182620f4b53cSBarry Smith            or `NULL`, if `d_nz` is used to specify the nonzero structure.
182720f4b53cSBarry Smith            The size of this array is equal to the number of local rows, i.e `m`.
182820f4b53cSBarry Smith            For matrices you plan to factor you must leave room for the diagonal entry and
182920f4b53cSBarry Smith            put in the entry even if it is zero.
183020f4b53cSBarry Smith . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
183120f4b53cSBarry Smith            submatrix (same value is used for all local rows).
183220f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the
183320f4b53cSBarry Smith            OFF-DIAGONAL portion of the local submatrix (possibly different for
183420f4b53cSBarry Smith            each row) or `NULL`, if `o_nz` is used to specify the nonzero
183520f4b53cSBarry Smith            structure. The size of this array is equal to the number
183620f4b53cSBarry Smith            of local rows, i.e `m`.
18378c3ff71bSJunchao Zhang 
18388c3ff71bSJunchao Zhang   Output Parameter:
18398c3ff71bSJunchao Zhang . A - the matrix
18408c3ff71bSJunchao Zhang 
18412ef1f0ffSBarry Smith   Level: intermediate
18422ef1f0ffSBarry Smith 
18432ef1f0ffSBarry Smith   Notes:
184411a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
18458c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
184611a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
18478c3ff71bSJunchao Zhang 
1848667f096bSBarry Smith   The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
18498c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
18502ef1f0ffSBarry Smith   either one (as in Fortran) or zero.
18518c3ff71bSJunchao Zhang 
18521cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1853fe59aa6dSJacob Faibussowitsch           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
18548c3ff71bSJunchao Zhang @*/
1855d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1856d71ae5a4SJacob Faibussowitsch {
18578c3ff71bSJunchao Zhang   PetscMPIInt size;
18588c3ff71bSJunchao Zhang 
18598c3ff71bSJunchao Zhang   PetscFunctionBegin;
18609566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
18619566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
18629566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
18638c3ff71bSJunchao Zhang   if (size > 1) {
18649566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
18659566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
18668c3ff71bSJunchao Zhang   } else {
18679566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
18689566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
18698c3ff71bSJunchao Zhang   }
18703ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18718c3ff71bSJunchao Zhang }
1872