xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 2cdb1aead714728ece08262a74a00c3a30a10b8f)
1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
52c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
68c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
80e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
911d22bbfSJunchao Zhang 
1066976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11d71ae5a4SJacob Faibussowitsch {
1230203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
138c3ff71bSJunchao Zhang 
148c3ff71bSJunchao Zhang   PetscFunctionBegin;
159566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1630203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1730203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1830203840SJunchao Zhang    */
1930203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2030203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2130203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2230203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2330203840SJunchao Zhang   }
243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
258c3ff71bSJunchao Zhang }
268c3ff71bSJunchao Zhang 
2766976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
28d71ae5a4SJacob Faibussowitsch {
29*2cdb1aeaSJunchao Zhang   Mat_MPIAIJ *mpiaij;
308c3ff71bSJunchao Zhang 
318c3ff71bSJunchao Zhang   PetscFunctionBegin;
32*2cdb1aeaSJunchao Zhang   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
33*2cdb1aeaSJunchao Zhang   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
34*2cdb1aeaSJunchao Zhang   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
35*2cdb1aeaSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
36*2cdb1aeaSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
388c3ff71bSJunchao Zhang }
398c3ff71bSJunchao Zhang 
4066976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
41d71ae5a4SJacob Faibussowitsch {
428c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
438c3ff71bSJunchao Zhang   PetscInt    nt;
448c3ff71bSJunchao Zhang 
458c3ff71bSJunchao Zhang   PetscFunctionBegin;
469566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
4708401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
489566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
499566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
509566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
519566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
523ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
538c3ff71bSJunchao Zhang }
548c3ff71bSJunchao Zhang 
5566976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
56d71ae5a4SJacob Faibussowitsch {
578c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
588c3ff71bSJunchao Zhang   PetscInt    nt;
598c3ff71bSJunchao Zhang 
608c3ff71bSJunchao Zhang   PetscFunctionBegin;
619566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
6208401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
639566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
649566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
659566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
669566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
688c3ff71bSJunchao Zhang }
698c3ff71bSJunchao Zhang 
7066976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
71d71ae5a4SJacob Faibussowitsch {
728c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
738c3ff71bSJunchao Zhang   PetscInt    nt;
748c3ff71bSJunchao Zhang 
758c3ff71bSJunchao Zhang   PetscFunctionBegin;
769566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
7708401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
789566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
799566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
809566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
819566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
838c3ff71bSJunchao Zhang }
848c3ff71bSJunchao Zhang 
85076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
86076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
87076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
88076ba34aSJunchao Zhang */
8966976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
90d71ae5a4SJacob Faibussowitsch {
91076ba34aSJunchao Zhang   Mat             Ad, Ao;
92076ba34aSJunchao Zhang   const PetscInt *cmap;
93076ba34aSJunchao Zhang 
94076ba34aSJunchao Zhang   PetscFunctionBegin;
959566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
969566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
97076ba34aSJunchao Zhang   if (glob) {
98076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
999566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1009566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1019566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1029566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
103076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
104076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1059566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
106076ba34aSJunchao Zhang   }
1073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
108076ba34aSJunchao Zhang }
109076ba34aSJunchao Zhang 
1100e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
111076ba34aSJunchao Zhang struct MatMatStruct {
1120e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1130e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1140e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1150e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1160e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1170e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1180e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1190e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1200e3ece09SJunchao Zhang 
1210e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1220e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1230e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1240e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1250e3ece09SJunchao Zhang 
126aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1270e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1280e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1290e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1300e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1310e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
132076ba34aSJunchao Zhang 
133d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
134d71ae5a4SJacob Faibussowitsch   {
1353ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1363ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1373ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
138076ba34aSJunchao Zhang   }
139076ba34aSJunchao Zhang };
140076ba34aSJunchao Zhang 
141076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1420e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1430e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1440e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
145076ba34aSJunchao Zhang };
146076ba34aSJunchao Zhang 
147076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1480e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1490e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1500e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1510e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
152076ba34aSJunchao Zhang };
153076ba34aSJunchao Zhang 
1549371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1553ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1563ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1573ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
1580e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
159076ba34aSJunchao Zhang 
160d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
161d71ae5a4SJacob Faibussowitsch   {
162076ba34aSJunchao Zhang     delete mmAB;
163076ba34aSJunchao Zhang     delete mmAtB;
1640e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
165076ba34aSJunchao Zhang   }
166076ba34aSJunchao Zhang };
167076ba34aSJunchao Zhang 
168d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
169d71ae5a4SJacob Faibussowitsch {
170076ba34aSJunchao Zhang   PetscFunctionBegin;
1719566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
1723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
173076ba34aSJunchao Zhang }
174076ba34aSJunchao Zhang 
175076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
176076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
177076ba34aSJunchao Zhang 
178076ba34aSJunchao Zhang   Input Parameters:
179076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
180076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
181076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
182076ba34aSJunchao Zhang 
1832fe279fdSBarry Smith   Output Parameter:
184076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
185076ba34aSJunchao Zhang */
1860e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
187d71ae5a4SJacob Faibussowitsch {
188076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
189076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
190076ba34aSJunchao Zhang 
191076ba34aSJunchao Zhang   PetscFunctionBegin;
1929566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
1939566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
1949566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
1959566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
196076ba34aSJunchao Zhang 
197aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
19808401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
1990e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
20008401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
201076ba34aSJunchao Zhang   mpiaij->A      = A;
202076ba34aSJunchao Zhang   mpiaij->B      = B;
2030e3ece09SJunchao Zhang   mpiaij->garray = garray;
204076ba34aSJunchao Zhang 
205076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
206076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
207076ba34aSJunchao Zhang 
2089566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2099566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
210076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
211076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
212076ba34aSJunchao Zhang   */
2139566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2149566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2159566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
217076ba34aSJunchao Zhang }
218076ba34aSJunchao Zhang 
2190e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
2200e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
2210e3ece09SJunchao Zhang template <class ExecutionSpace>
2220e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
223d71ae5a4SJacob Faibussowitsch {
2240e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
225076ba34aSJunchao Zhang 
226076ba34aSJunchao Zhang   PetscFunctionBegin;
2270e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
228076ba34aSJunchao Zhang 
2290e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
230076ba34aSJunchao Zhang 
2310e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
232076ba34aSJunchao Zhang 
2330e3ece09SJunchao Zhang   if (vector_length < 1) {
2340e3ece09SJunchao Zhang     vector_length = 1;
2350e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
236076ba34aSJunchao Zhang   }
237076ba34aSJunchao Zhang 
2380e3ece09SJunchao Zhang   // Determine rows per thread
2390e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2400e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
2410e3ece09SJunchao Zhang     else {
2420e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2430e3ece09SJunchao Zhang         rows_per_thread = 256;
2440e3ece09SJunchao Zhang       } else rows_per_thread = 64;
245076ba34aSJunchao Zhang     }
246076ba34aSJunchao Zhang   }
247076ba34aSJunchao Zhang 
2480e3ece09SJunchao Zhang   if (team_size < 1) {
2490e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
2500e3ece09SJunchao Zhang       team_size = 256 / vector_length;
251076ba34aSJunchao Zhang     } else {
2520e3ece09SJunchao Zhang       team_size = 1;
2530e3ece09SJunchao Zhang     }
254076ba34aSJunchao Zhang   }
255076ba34aSJunchao Zhang 
2560e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
257076ba34aSJunchao Zhang 
2580e3ece09SJunchao Zhang   if (rows_per_team < 0) {
2590e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
2600e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
2610e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
2620e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
2630e3ece09SJunchao Zhang   }
2643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
265076ba34aSJunchao Zhang }
266076ba34aSJunchao Zhang 
2670e3ece09SJunchao Zhang /*
2680e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
269076ba34aSJunchao Zhang 
270076ba34aSJunchao Zhang   Input Parameters:
2710e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
2720e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
2730e3ece09SJunchao Zhang .  m           - size of indices[], the second set
2740e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
275076ba34aSJunchao Zhang 
276076ba34aSJunchao Zhang   Output Parameters:
2770e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
2780e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
2790e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
2800e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
281076ba34aSJunchao Zhang 
2820e3ece09SJunchao Zhang    Example, say
2830e3ece09SJunchao Zhang     n1         = 5
2840e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
2850e3ece09SJunchao Zhang     m          = 4
2860e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
28711a5261eSBarry Smith 
2880e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
2890e3ece09SJunchao Zhang     n2         = 7
2900e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
2910e3ece09SJunchao Zhang 
2920e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
2930e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
2940e3ece09SJunchao Zhang 
2950e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
2960e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
297076ba34aSJunchao Zhang */
2980e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
299d71ae5a4SJacob Faibussowitsch {
3000e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
3010e3ece09SJunchao Zhang   PetscHashIter iter;
3020e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
3030e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
304076ba34aSJunchao Zhang 
305076ba34aSJunchao Zhang   PetscFunctionBegin;
3060e3ece09SJunchao Zhang   tot = 0;
3070e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
3080e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
3090e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
3100e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
311076ba34aSJunchao Zhang   }
312076ba34aSJunchao Zhang 
3130e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
3140e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
3150e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
316076ba34aSJunchao Zhang   }
317076ba34aSJunchao Zhang 
3180e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
3190e3ece09SJunchao Zhang   n2 = tot;
3200e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
3210e3ece09SJunchao Zhang   tot = 0;
3220e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
3230e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
3240e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
3250e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
3260e3ece09SJunchao Zhang     garray2[tot++] = key;
327076ba34aSJunchao Zhang   }
328076ba34aSJunchao Zhang 
3290e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
3300e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
3310e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
3320e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
333f0e6e2d1SJunchao Zhang 
3340e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
335f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3360e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3370e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3380e3ece09SJunchao Zhang     indices[i] = val;
3390e3ece09SJunchao Zhang   }
3400e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3410e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3420e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3430e3ece09SJunchao Zhang   *n2_      = n2;
3440e3ece09SJunchao Zhang   *garray2_ = garray2;
3450e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3460e3ece09SJunchao Zhang }
347f0e6e2d1SJunchao Zhang 
3480e3ece09SJunchao Zhang /*
3490e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3500e3ece09SJunchao Zhang 
3510e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
3520e3ece09SJunchao Zhang 
3530e3ece09SJunchao Zhang   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
3540e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3550e3ece09SJunchao Zhang 
3560e3ece09SJunchao Zhang   Input Parameters:
3570e3ece09SJunchao Zhang +  comm       - MPI communicator of E
3580e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
3590e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
3600e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
3610e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
3620e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
3630e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
3640e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
3650e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
3660e3ece09SJunchao Zhang 
3670e3ece09SJunchao Zhang   Output Parameters:
3680e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
3690e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
3700e3ece09SJunchao Zhang 
3710e3ece09SJunchao Zhang   Notes:
3720e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
3730e3ece09SJunchao Zhang 
3740e3ece09SJunchao Zhang  */
3750e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
3760e3ece09SJunchao Zhang {
3770e3ece09SJunchao Zhang   PetscFunctionBegin;
3780e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
3790e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
3800e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
3810e3ece09SJunchao Zhang 
3820e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
3830e3ece09SJunchao Zhang 
3840e3ece09SJunchao Zhang     // Do the analysis on host
3850e3ece09SJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
3860e3ece09SJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
3870e3ece09SJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
3880e3ece09SJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
3890e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
3900e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
3910e3ece09SJunchao Zhang 
3920e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
3937b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
3940e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
3950e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
3960e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
3970e3ece09SJunchao Zhang       PetscInt        count, step;
3980e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
3990e3ece09SJunchao Zhang       first = Bj + Bi[i];
4000e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
401f0e6e2d1SJunchao Zhang       count = last - first;
402f0e6e2d1SJunchao Zhang       while (count > 0) {
403f0e6e2d1SJunchao Zhang         it   = first;
404f0e6e2d1SJunchao Zhang         step = count / 2;
405f0e6e2d1SJunchao Zhang         it += step;
4060e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
407f0e6e2d1SJunchao Zhang           first = ++it;
408f0e6e2d1SJunchao Zhang           count -= step + 1;
409f0e6e2d1SJunchao Zhang         } else count = step;
410f0e6e2d1SJunchao Zhang       }
4110e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
4120e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
413f0e6e2d1SJunchao Zhang     }
414f0e6e2d1SJunchao Zhang 
4150e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
4160e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
4170e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
4180e3ece09SJunchao Zhang     PetscInt           niranks, nranks;
4190e3ece09SJunchao Zhang     MPI_Request       *reqs;
4200e3ece09SJunchao Zhang     PetscMPIInt        tag;
4210e3ece09SJunchao Zhang     PetscSF            reduceSF;
4220e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
423f0e6e2d1SJunchao Zhang 
4240e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
4250e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
4260e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
427f0e6e2d1SJunchao Zhang 
4280e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
4290e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
4300e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
4310e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
432f0e6e2d1SJunchao Zhang 
4330e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4340e3ece09SJunchao Zhang 
4350e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4360e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4370e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4380e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4390e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4400e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4410e3ece09SJunchao Zhang 
4420e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4430e3ece09SJunchao Zhang     rdisp[0] = 0;
4440e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4450e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4460e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4470e3ece09SJunchao Zhang     }
4480e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4490e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4500e3ece09SJunchao Zhang 
4510e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4520e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4530e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4540e3ece09SJunchao Zhang 
4550e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
4560e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
4570e3ece09SJunchao Zhang     PetscSFNode *iremote;
4580e3ece09SJunchao Zhang 
4590e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
4600e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
4610e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
4620e3ece09SJunchao Zhang 
4630e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
4640e3ece09SJunchao Zhang       PetscInt count = 0;
4650e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
4660e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
4670e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
4680e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
4690e3ece09SJunchao Zhang       }
4700e3ece09SJunchao Zhang       nleaves += count;
4710e3ece09SJunchao Zhang     }
4720e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
4730e3ece09SJunchao Zhang 
4740e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
4750e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
4760e3ece09SJunchao Zhang 
4770e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
4780e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
4790e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
4800e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
4810e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
4820e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
4830e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
4840e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
4850e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
4860e3ece09SJunchao Zhang         if (j < nzLeft) {
4870e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
4880e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
4890e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
4900e3ece09SJunchao Zhang         } else {
4910e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
4920e3ece09SJunchao Zhang         }
4930e3ece09SJunchao Zhang       }
4940e3ece09SJunchao Zhang     }
4950e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
4960e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
4970e3ece09SJunchao Zhang 
4980e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
4990e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
5000e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
5010e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
5020e3ece09SJunchao Zhang 
5030e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
5040e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
5050e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
5060e3ece09SJunchao Zhang 
5070e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
5087b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
5090e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
5100e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
5110e3ece09SJunchao Zhang     PetscInt                iter;
5120e3ece09SJunchao Zhang 
5130e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
5140e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
5150e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
5160e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
5170e3ece09SJunchao Zhang     iter = 0;
5180e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
5190e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
5200e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
5210e3ece09SJunchao Zhang 
5220e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5230e3ece09SJunchao Zhang 
5240e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
5250e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
5260e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
5270e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
5280e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
5290e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
5300e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5310e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
5320e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5330e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5340e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5350e3ece09SJunchao Zhang         jbuf2 += len;
5360e3ece09SJunchao Zhang         pbuf2 += len;
5370e3ece09SJunchao Zhang         nz += len;
5380e3ece09SJunchao Zhang       }
5390e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5400e3ece09SJunchao Zhang 
5410e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5420e3ece09SJunchao Zhang       PetscInt cur = 0;
5430e3ece09SJunchao Zhang       while (cur < nz) {
5440e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5450e3ece09SJunchao Zhang         PetscInt dups      = 1;
5460e3ece09SJunchao Zhang 
5470e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5480e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5490e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5500e3ece09SJunchao Zhang           FdnzDups += dups;
5510e3ece09SJunchao Zhang         } else {
5520e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5530e3ece09SJunchao Zhang           FonzDups += dups;
5540e3ece09SJunchao Zhang         }
5550e3ece09SJunchao Zhang         cur += dups;
5560e3ece09SJunchao Zhang       }
5570e3ece09SJunchao Zhang 
5580e3ece09SJunchao Zhang       FnzDups += nz;
5590e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5600e3ece09SJunchao Zhang     }
5610e3ece09SJunchao Zhang 
5620e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
5630e3ece09SJunchao Zhang     Foi = Foi_h.data();
5640e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
5650e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
5660e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
5670e3ece09SJunchao Zhang     }
5680e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
5690e3ece09SJunchao Zhang     Fonz = Foi[Fm];
5700e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
5710e3ece09SJunchao Zhang 
5720e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
5737b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
5747b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
5757b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
5760e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
5770e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
5780e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
5790e3ece09SJunchao Zhang 
5800e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
5810e3ece09SJunchao Zhang     Fdjmap[0] = 0;
5820e3ece09SJunchao Zhang     Fojmap[0] = 0;
5830e3ece09SJunchao Zhang     FnzDups   = 0;
5840e3ece09SJunchao Zhang     Fdnz      = 0;
5850e3ece09SJunchao Zhang     Fonz      = 0;
5860e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
5870e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
5880e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
5890e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
5900e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
5910e3ece09SJunchao Zhang 
5920e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5930e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5940e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
5950e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
5960e3ece09SJunchao Zhang       }
5970e3ece09SJunchao Zhang 
5980e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
5990e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
6000e3ece09SJunchao Zhang       PetscInt cur = 0;
6010e3ece09SJunchao Zhang       while (cur < nz) {
6020e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
6030e3ece09SJunchao Zhang         PetscInt dups      = 1;
6040e3ece09SJunchao Zhang 
6050e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
6060e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
6070e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
6080e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
6090e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
6100e3ece09SJunchao Zhang           FdnzDups += dups;
6110e3ece09SJunchao Zhang           Fdnz++;
6120e3ece09SJunchao Zhang         } else {
6130e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
6140e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
6150e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
6160e3ece09SJunchao Zhang           FonzDups += dups;
6170e3ece09SJunchao Zhang           Fonz++;
6180e3ece09SJunchao Zhang         }
6190e3ece09SJunchao Zhang         cur += dups;
6200e3ece09SJunchao Zhang         FnzDups += dups;
6210e3ece09SJunchao Zhang       }
6220e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6230e3ece09SJunchao Zhang     }
6240e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
6250e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
6260e3ece09SJunchao Zhang 
6270e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
6280e3ece09SJunchao Zhang     PetscInt n2, *garray2;
6290e3ece09SJunchao Zhang 
6300e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
6310e3ece09SJunchao Zhang     mm->sf       = reduceSF;
6327b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
6337b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
634aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6350e3ece09SJunchao Zhang     mm->n        = n2;
6360e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6370e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6380e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6390e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6400e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6410e3ece09SJunchao Zhang 
6420e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
6437b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6440e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6450e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
6467b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6470e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6480e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6490e3ece09SJunchao Zhang 
6500e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6510e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6520e3ece09SJunchao Zhang 
6530e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6540e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6550e3ece09SJunchao Zhang 
6560e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
6570e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
6580e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
6590e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
6600e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
6610e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
6620e3ece09SJunchao Zhang 
6630e3ece09SJunchao Zhang   // Handy aliases
6640e3ece09SJunchao Zhang   auto       &Aa           = A.values;
6650e3ece09SJunchao Zhang   auto       &Ba           = B.values;
6660e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
6670e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
6680e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
6690e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
6700e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
6710e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
6720e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
6730e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
6740e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
6750e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
6760e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
6770e3ece09SJunchao Zhang 
6780e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
6790e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
680d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
6810e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
6820e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
6830e3ece09SJunchao Zhang         if (i < Em) {
6840e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
6850e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
6860e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
6870e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
6880e3ece09SJunchao Zhang 
6890e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
6900e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
6910e3ece09SJunchao Zhang             if (j < nzleft) { // B left
6920e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
6930e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
6940e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
6950e3ece09SJunchao Zhang             } else { // B right
6960e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
697f0e6e2d1SJunchao Zhang             }
698f0e6e2d1SJunchao Zhang           });
699f0e6e2d1SJunchao Zhang         }
700f0e6e2d1SJunchao Zhang       });
7010e3ece09SJunchao Zhang     }));
7020e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
703f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
704f0e6e2d1SJunchao Zhang }
7050e3ece09SJunchao Zhang 
706aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
7070e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
7080e3ece09SJunchao Zhang {
7090e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
7100e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
7110e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
7120e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
7130e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
7140e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
7150e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
7160e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
7170e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
7180e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
7190e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
7200e3ece09SJunchao Zhang 
721d326c3f1SJunchao Zhang   PetscFunctionBegin;
7220e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
7230e3ece09SJunchao Zhang 
7240e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
7250e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
726d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
7270e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7280e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
7290e3ece09SJunchao Zhang       Fda(i) = sum;
7300e3ece09SJunchao Zhang     }));
7310e3ece09SJunchao Zhang 
7320e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
733d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
7340e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7350e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7360e3ece09SJunchao Zhang       Foa(i) = sum;
7370e3ece09SJunchao Zhang     }));
7380e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7390e3ece09SJunchao Zhang }
7400e3ece09SJunchao Zhang 
7410e3ece09SJunchao Zhang /*
7420e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7430e3ece09SJunchao Zhang 
7440e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
7450e3ece09SJunchao Zhang   device and involves various index mapping.
7460e3ece09SJunchao Zhang 
7470e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
  If F's j-th row is connected to a root identified by PetscSFNode (k,i), we need to bcast the i-th row of E on rank k
  to the j-th row of F. ownerSF is not an arbitrary SF; instead, it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7500e3ece09SJunchao Zhang   F has the same column layout as E.
7510e3ece09SJunchao Zhang 
  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
753aaa8cc7dSPierre Jolivet   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
7540e3ece09SJunchao Zhang   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
7550e3ece09SJunchao Zhang   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
7560e3ece09SJunchao Zhang   column indices in Fo and update Fo with local indices.
7570e3ece09SJunchao Zhang 
7580e3ece09SJunchao Zhang    Input Parameters:
7590e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
7609c89aa79SPierre Jolivet .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
7610e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
7620e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
7630e3ece09SJunchao Zhang 
7640e3ece09SJunchao Zhang     Output Parameters:
7650e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
7660e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
7670e3ece09SJunchao Zhang 
7680e3ece09SJunchao Zhang     Notes:
7690e3ece09SJunchao Zhang     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
    The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
7710e3ece09SJunchao Zhang */
7720e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
7730e3ece09SJunchao Zhang {
7740e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
7750e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
7760e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
7770e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
7780e3ece09SJunchao Zhang   MPI_Comm          comm;
7790e3ece09SJunchao Zhang 
7800e3ece09SJunchao Zhang   PetscFunctionBegin;
7810e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
7820e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
7830e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
7840e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
7850e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
7860e3ece09SJunchao Zhang     PetscInt        cstart, cend;
7870e3ece09SJunchao Zhang     PetscSF         bcastSF;
7880e3ece09SJunchao Zhang 
7890e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
7900e3ece09SJunchao Zhang 
7910e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
7927b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
7930e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
7940e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
7950e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
7960e3ece09SJunchao Zhang       PetscInt        count, step;
7970e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
7980e3ece09SJunchao Zhang       first = Bj + Bi[i];
7990e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
8000e3ece09SJunchao Zhang       count = last - first;
8010e3ece09SJunchao Zhang       while (count > 0) {
8020e3ece09SJunchao Zhang         it   = first;
8030e3ece09SJunchao Zhang         step = count / 2;
8040e3ece09SJunchao Zhang         it += step;
8050e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
8060e3ece09SJunchao Zhang           first = ++it;
8070e3ece09SJunchao Zhang           count -= step + 1;
8080e3ece09SJunchao Zhang         } else count = step;
8090e3ece09SJunchao Zhang       }
8100e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
8110e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
8120e3ece09SJunchao Zhang     }
8130e3ece09SJunchao Zhang 
8140e3ece09SJunchao Zhang     // Compute row pointer Fi of F
8150e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
8160e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
8170e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
8180e3ece09SJunchao Zhang     Fi[0] = 0;
8190e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
8200e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
8210e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
8220e3ece09SJunchao Zhang     Fnz = Fi[Fm];
8230e3ece09SJunchao Zhang 
8240e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
8250e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
8260e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
8270e3ece09SJunchao Zhang     PetscInt           niranks, nranks, *sdisp, *rdisp;
8280e3ece09SJunchao Zhang     MPI_Request       *reqs;
8290e3ece09SJunchao Zhang     PetscMPIInt        tag;
8300e3ece09SJunchao Zhang 
8310e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8320e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8330e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8340e3ece09SJunchao Zhang 
8350e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8360e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8370e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8380e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8390e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8400e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8410e3ece09SJunchao Zhang       }
8420e3ece09SJunchao Zhang     }
8430e3ece09SJunchao Zhang 
8440e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8450e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8460e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8470e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8480e3ece09SJunchao Zhang 
8490e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8500e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8510e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8520e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8530e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8540e3ece09SJunchao Zhang       PetscInt k = 0;
8550e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
8560e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
8570e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
8580e3ece09SJunchao Zhang         k++;
8590e3ece09SJunchao Zhang       }
8600e3ece09SJunchao Zhang     }
8610e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
8620e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
8630e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
8640e3ece09SJunchao Zhang 
8650e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
8667b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
8670e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
8680e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
8697b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
8700e3ece09SJunchao Zhang 
8710e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
8720e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
8730e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
8740e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
8750e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
8760e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
8770e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
8780e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
8790e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
8800e3ece09SJunchao Zhang         if (j < nzLeft) {
8810e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
8820e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
8830e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
8840e3ece09SJunchao Zhang         } else {
8850e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
8860e3ece09SJunchao Zhang         }
8870e3ece09SJunchao Zhang       }
8880e3ece09SJunchao Zhang     }
8890e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
8900e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
8910e3ece09SJunchao Zhang 
8920e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
8937b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
8947b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
8950e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
8960e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
8970e3ece09SJunchao Zhang 
8980e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
8990e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9000e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
9010e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
9020e3ece09SJunchao Zhang       first       = Fj + Fi[i];
9030e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
9040e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
9050e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
9060e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
9070e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
9080e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
9090e3ece09SJunchao Zhang     }
9100e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9110e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
9120e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
9130e3ece09SJunchao Zhang     }
9140e3ece09SJunchao Zhang 
9150e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
9160e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
9177b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
9180e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
9190e3ece09SJunchao Zhang 
9200e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9210e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
9220e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
9230e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
9240e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
9250e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
9260e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
9270e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
9280e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
9290e3ece09SJunchao Zhang         } else { // right, in global
9300e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
9310e3ece09SJunchao Zhang         }
9320e3ece09SJunchao Zhang       }
9330e3ece09SJunchao Zhang     }
9340e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9350e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9360e3ece09SJunchao Zhang 
9370e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9380e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9390e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9400e3ece09SJunchao Zhang 
9410e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9420e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
9437b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9440e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9450e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9460e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9470e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9480e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9490e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
9507b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
9517b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9520e3ece09SJunchao Zhang     mm->garray    = garray2;
9530e3ece09SJunchao Zhang     mm->n         = n2;
9540e3ece09SJunchao Zhang 
9550e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
9567b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
9570e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
9580e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
9590e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
9600e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
9610e3ece09SJunchao Zhang 
9620e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
9630e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
9640e3ece09SJunchao Zhang 
9650e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
9660e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
9670e3ece09SJunchao Zhang 
9680e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
9690e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
9700e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
9710e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
9720e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
9730e3ece09SJunchao Zhang 
9740e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
9750e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
9760e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
9770e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
9780e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
9790e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
9800e3ece09SJunchao Zhang 
9810e3ece09SJunchao Zhang   // Sync E's value to device
9820e3ece09SJunchao Zhang   akok->a_dual.sync_device();
9830e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
9840e3ece09SJunchao Zhang 
9850e3ece09SJunchao Zhang   // Handy aliases
9860e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
9870e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
9880e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
9890e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
9900e3ece09SJunchao Zhang 
9910e3ece09SJunchao Zhang   // Fetch the plans
9920e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
9930e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
9940e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
9950e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
9960e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
9970e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
9980e3ece09SJunchao Zhang 
9990e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
10000e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
10010e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
10020e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
10030e3ece09SJunchao Zhang 
10040e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
10050e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1006d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10070e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10080e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
10090e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
10100e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
10110e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
10120e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
10130e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
10140e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
10150e3ece09SJunchao Zhang 
10160e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10170e3ece09SJunchao Zhang             if (j < nzleft) { // B left
10180e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
10190e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
10200e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
10210e3ece09SJunchao Zhang             } else { // B right
10220e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
10230e3ece09SJunchao Zhang             }
10240e3ece09SJunchao Zhang           });
10250e3ece09SJunchao Zhang         }
10260e3ece09SJunchao Zhang       });
10270e3ece09SJunchao Zhang     }));
10280e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
10290e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10300e3ece09SJunchao Zhang }
10310e3ece09SJunchao Zhang 
10320e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10330e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10340e3ece09SJunchao Zhang {
10350e3ece09SJunchao Zhang   PetscFunctionBegin;
10360e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10370e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10380e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10390e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10400e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10410e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10420e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10430e3ece09SJunchao Zhang 
10440e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10450e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10460e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10470e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10480e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10490e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10500e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10510e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10520e3ece09SJunchao Zhang 
10530e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10540e3ece09SJunchao Zhang 
10550e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
10560e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1057d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10580e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10590e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
10600e3ece09SJunchao Zhang         if (i < Fm) {
10610e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
10620e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
10630e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
10640e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
10650e3ece09SJunchao Zhang 
10660e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10670e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
10680e3ece09SJunchao Zhang             if (j < nzLeft) { // left
10690e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
10700e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
10710e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
10720e3ece09SJunchao Zhang             } else { // right
10730e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
10740e3ece09SJunchao Zhang             }
10750e3ece09SJunchao Zhang           });
10760e3ece09SJunchao Zhang         }
10770e3ece09SJunchao Zhang       });
10780e3ece09SJunchao Zhang     }));
10790e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10800e3ece09SJunchao Zhang }
10810e3ece09SJunchao Zhang 
10820e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
10830e3ece09SJunchao Zhang {
10840e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
10850e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
10860e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
10870e3ece09SJunchao Zhang   PetscInt        cstart, cend;
10880e3ece09SJunchao Zhang   MPI_Comm        comm;
10890e3ece09SJunchao Zhang 
10900e3ece09SJunchao Zhang   PetscFunctionBegin;
10910e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
10920e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
10930e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
10940e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
10950e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
10960e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
10970e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
10980e3ece09SJunchao Zhang 
10990e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11000e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11010e3ece09SJunchao Zhang 
11020e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11030e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11040e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11050e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1106f0e6e2d1SJunchao Zhang   #endif
11070e3ece09SJunchao Zhang #endif
11080e3ece09SJunchao Zhang 
11090e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
11100e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
11110e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
11120e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
11130e3ece09SJunchao Zhang 
11140e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11150e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
11160e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
11170e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
11180e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
11190e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11200e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11210e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1122d326c3f1SJunchao Zhang 
11230e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
11240e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
11250e3ece09SJunchao Zhang #endif
11260e3ece09SJunchao Zhang 
11270e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11287b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11290e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
11300e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11310e3ece09SJunchao Zhang 
11320e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11340e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11350e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11360e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11370e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11380e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11390e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11400e3ece09SJunchao Zhang #endif
11410e3ece09SJunchao Zhang 
11420e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11430e3ece09SJunchao Zhang 
11440e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
11457b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11460e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1147d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11480e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11490e3ece09SJunchao Zhang 
11500e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11510e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11520e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11530e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11540e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11550e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11560e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11570e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11580e3ece09SJunchao Zhang }
11590e3ece09SJunchao Zhang 
11600e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11610e3ece09SJunchao Zhang {
11620e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11630e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11640e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
11650e3ece09SJunchao Zhang   MPI_Comm        comm;
11660e3ece09SJunchao Zhang 
11670e3ece09SJunchao Zhang   PetscFunctionBegin;
11680e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11690e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11700e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11710e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11720e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11730e3ece09SJunchao Zhang 
11740e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11750e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11760e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11770e3ece09SJunchao Zhang 
11780e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11790e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
11800e3ece09SJunchao Zhang 
11810e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11820e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11830e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11840e3ece09SJunchao Zhang 
11850e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
11860e3ece09SJunchao Zhang 
11870e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11880e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11890e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11900e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11910e3ece09SJunchao Zhang }
1192f0e6e2d1SJunchao Zhang 
1193076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1194076ba34aSJunchao Zhang 
1195076ba34aSJunchao Zhang   Input Parameters:
1196076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1197076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1198076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1199076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1200076ba34aSJunchao Zhang */
1201d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1202d71ae5a4SJacob Faibussowitsch {
12030e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12040e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12050e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1206076ba34aSJunchao Zhang 
1207076ba34aSJunchao Zhang   PetscFunctionBegin;
12080e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12090e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12100e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12110e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12120e3ece09SJunchao Zhang 
12130e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
12140e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
12150e3ece09SJunchao Zhang 
12160e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
12170e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
12180e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
12190e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
12200e3ece09SJunchao Zhang   #endif
1221f0e6e2d1SJunchao Zhang #endif
1222f0e6e2d1SJunchao Zhang 
12230e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
12240e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
12250e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
12260e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1227076ba34aSJunchao Zhang 
12280e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12297b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
12300e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1231076ba34aSJunchao Zhang 
12320e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12340e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12350e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12360e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12370e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12380e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12390e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12400e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12410e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12420e3ece09SJunchao Zhang #endif
1243076ba34aSJunchao Zhang 
12440e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1245076ba34aSJunchao Zhang 
12460e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12470e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12480e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12490e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12500e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12510e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12520e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12530e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12540e3ece09SJunchao Zhang #endif
1255076ba34aSJunchao Zhang 
12560e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
12577b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
12580e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1259d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
12600e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
12610e3ece09SJunchao Zhang 
12620e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12630e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
12640e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
12650e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
12660e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
12670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
12680e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
12693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1270076ba34aSJunchao Zhang }
1271076ba34aSJunchao Zhang 
12720e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1273d71ae5a4SJacob Faibussowitsch {
12740e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12750e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12760e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1277076ba34aSJunchao Zhang 
1278076ba34aSJunchao Zhang   PetscFunctionBegin;
12790e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12800e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12810e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12820e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1283076ba34aSJunchao Zhang 
12840e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12850e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1286076ba34aSJunchao Zhang 
12870e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12880e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12890e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1290076ba34aSJunchao Zhang 
12910e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1292076ba34aSJunchao Zhang 
12930e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12940e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12950e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12960e3ece09SJunchao Zhang 
12970e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12980e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
12990e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1301076ba34aSJunchao Zhang }
1302076ba34aSJunchao Zhang 
130366976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1304d71ae5a4SJacob Faibussowitsch {
13050e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
13060e3ece09SJunchao Zhang   Mat_Product                 *product;
13070e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1308076ba34aSJunchao Zhang   MatProductType               ptype;
13090e3ece09SJunchao Zhang   Mat                          A, B;
1310076ba34aSJunchao Zhang 
1311076ba34aSJunchao Zhang   PetscFunctionBegin;
13120e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
13130e3ece09SJunchao Zhang   product = C->product;
13140e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1315076ba34aSJunchao Zhang   ptype   = product->type;
1316076ba34aSJunchao Zhang   A       = product->A;
1317076ba34aSJunchao Zhang   B       = product->B;
1318076ba34aSJunchao Zhang 
13190e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
13200e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
13210e3ece09SJunchao Zhang   // we still do numeric.
13220e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
13230e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13243ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1325076ba34aSJunchao Zhang   }
1326076ba34aSJunchao Zhang 
1327076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13280e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1329076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13300e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
13310e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13320e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13330e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1334076ba34aSJunchao Zhang   }
13350e3ece09SJunchao Zhang 
13360e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13370e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1339076ba34aSJunchao Zhang }
1340076ba34aSJunchao Zhang 
134166976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1342d71ae5a4SJacob Faibussowitsch {
1343076ba34aSJunchao Zhang   Mat                          A, B;
13440e3ece09SJunchao Zhang   Mat_Product                 *product;
1345076ba34aSJunchao Zhang   MatProductType               ptype;
13460e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1347076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13480e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13490e3ece09SJunchao Zhang   Mat                          Cd, Co;
13500e3ece09SJunchao Zhang   MPI_Comm                     comm;
1351076ba34aSJunchao Zhang 
1352076ba34aSJunchao Zhang   PetscFunctionBegin;
13530e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1354076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13550e3ece09SJunchao Zhang   product = C->product;
13560e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1357076ba34aSJunchao Zhang   ptype = product->type;
1358076ba34aSJunchao Zhang   A     = product->A;
1359076ba34aSJunchao Zhang   B     = product->B;
1360076ba34aSJunchao Zhang 
1361076ba34aSJunchao Zhang   switch (ptype) {
13629371c9d4SSatish Balay   case MATPRODUCT_AB:
13639371c9d4SSatish Balay     m = A->rmap->n;
13649371c9d4SSatish Balay     n = B->cmap->n;
13659371c9d4SSatish Balay     M = A->rmap->N;
13669371c9d4SSatish Balay     N = B->cmap->N;
13679371c9d4SSatish Balay     break;
13689371c9d4SSatish Balay   case MATPRODUCT_AtB:
13699371c9d4SSatish Balay     m = A->cmap->n;
13709371c9d4SSatish Balay     n = B->cmap->n;
13719371c9d4SSatish Balay     M = A->cmap->N;
13729371c9d4SSatish Balay     N = B->cmap->N;
13739371c9d4SSatish Balay     break;
13749371c9d4SSatish Balay   case MATPRODUCT_PtAP:
13759371c9d4SSatish Balay     m = B->cmap->n;
13769371c9d4SSatish Balay     n = B->cmap->n;
13779371c9d4SSatish Balay     M = B->cmap->N;
13789371c9d4SSatish Balay     N = B->cmap->N;
13799371c9d4SSatish Balay     break; /* BtAB */
1380d71ae5a4SJacob Faibussowitsch   default:
13810e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1382076ba34aSJunchao Zhang   }
1383076ba34aSJunchao Zhang 
13849566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
13859566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
13869566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
13879566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1388076ba34aSJunchao Zhang 
13890e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
13900e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1391076ba34aSJunchao Zhang 
1392076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13930e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
13940e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
13950e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1396076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13970e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
13980e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
13990e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
14000e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
14010e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
14020e3ece09SJunchao Zhang 
14030e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14040e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
14050e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
14060e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
14070e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
14080e3ece09SJunchao Zhang 
14090e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
14100e3ece09SJunchao Zhang     n = B->cmap->n;
14110e3ece09SJunchao Zhang     M = A->rmap->N;
14120e3ece09SJunchao Zhang     N = B->cmap->N;
14130e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
14140e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
14150e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
14160e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
14170e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
14180e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
14190e3ece09SJunchao Zhang 
14200e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14210e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
14220e3ece09SJunchao Zhang 
14230e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
14240e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1425076ba34aSJunchao Zhang   }
14260e3ece09SJunchao Zhang 
14270e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
14280e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
14290e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
14300e3ece09SJunchao Zhang 
14310e3ece09SJunchao Zhang   C->product->data       = pdata;
1432076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1433076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1435076ba34aSJunchao Zhang }
1436076ba34aSJunchao Zhang 
1437d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1438d71ae5a4SJacob Faibussowitsch {
1439076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1440076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1441076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1442076ba34aSJunchao Zhang 
1443076ba34aSJunchao Zhang   PetscFunctionBegin;
1444076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
144548a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1446076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1447076ba34aSJunchao Zhang     switch (product->type) {
1448076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1449076ba34aSJunchao Zhang       if (product->api_user) {
1450d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14519566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1452d0609cedSBarry Smith         PetscOptionsEnd();
1453076ba34aSJunchao Zhang       } else {
1454d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14559566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1456d0609cedSBarry Smith         PetscOptionsEnd();
1457076ba34aSJunchao Zhang       }
1458076ba34aSJunchao Zhang       break;
1459076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1460076ba34aSJunchao Zhang       if (product->api_user) {
1461d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
14629566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1463d0609cedSBarry Smith         PetscOptionsEnd();
1464076ba34aSJunchao Zhang       } else {
1465d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
14669566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1467d0609cedSBarry Smith         PetscOptionsEnd();
1468076ba34aSJunchao Zhang       }
1469076ba34aSJunchao Zhang       break;
1470076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1471076ba34aSJunchao Zhang       if (product->api_user) {
1472d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
14739566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1474d0609cedSBarry Smith         PetscOptionsEnd();
1475076ba34aSJunchao Zhang       } else {
1476d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
14779566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1478d0609cedSBarry Smith         PetscOptionsEnd();
1479076ba34aSJunchao Zhang       }
1480076ba34aSJunchao Zhang       break;
1481d71ae5a4SJacob Faibussowitsch     default:
1482d71ae5a4SJacob Faibussowitsch       break;
1483076ba34aSJunchao Zhang     }
1484076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1485076ba34aSJunchao Zhang   }
1486076ba34aSJunchao Zhang   if (match) {
1487076ba34aSJunchao Zhang     switch (product->type) {
1488076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1489076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1490d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1491d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1492d71ae5a4SJacob Faibussowitsch       break;
1493d71ae5a4SJacob Faibussowitsch     default:
1494d71ae5a4SJacob Faibussowitsch       break;
1495076ba34aSJunchao Zhang     }
1496076ba34aSJunchao Zhang   }
1497076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
149848a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
14993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1500076ba34aSJunchao Zhang }
1501076ba34aSJunchao Zhang 
15022c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device
15032c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos {
15042c4ab24aSJunchao Zhang   PetscCount           n;
15052c4ab24aSJunchao Zhang   PetscSF              sf;
15062c4ab24aSJunchao Zhang   PetscCount           Annz, Bnnz;
15072c4ab24aSJunchao Zhang   PetscCount           Annz2, Bnnz2;
15082c4ab24aSJunchao Zhang   PetscCountKokkosView Ajmap1, Aperm1;
15092c4ab24aSJunchao Zhang   PetscCountKokkosView Bjmap1, Bperm1;
15102c4ab24aSJunchao Zhang   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
15112c4ab24aSJunchao Zhang   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
15122c4ab24aSJunchao Zhang   PetscCountKokkosView Cperm1;
15132c4ab24aSJunchao Zhang   MatScalarKokkosView  sendbuf, recvbuf;
15142c4ab24aSJunchao Zhang 
15152c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) :
15162c4ab24aSJunchao Zhang     n(coo_h->n),
15172c4ab24aSJunchao Zhang     sf(coo_h->sf),
15182c4ab24aSJunchao Zhang     Annz(coo_h->Annz),
15192c4ab24aSJunchao Zhang     Bnnz(coo_h->Bnnz),
15202c4ab24aSJunchao Zhang     Annz2(coo_h->Annz2),
15212c4ab24aSJunchao Zhang     Bnnz2(coo_h->Bnnz2),
15222c4ab24aSJunchao Zhang     Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))),
15232c4ab24aSJunchao Zhang     Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))),
15242c4ab24aSJunchao Zhang     Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))),
15252c4ab24aSJunchao Zhang     Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))),
15262c4ab24aSJunchao Zhang     Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))),
15272c4ab24aSJunchao Zhang     Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))),
15282c4ab24aSJunchao Zhang     Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))),
15292c4ab24aSJunchao Zhang     Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))),
15302c4ab24aSJunchao Zhang     Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))),
15312c4ab24aSJunchao Zhang     Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))),
15322c4ab24aSJunchao Zhang     Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))),
15336892b982SJunchao Zhang     sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))),
15346892b982SJunchao Zhang     recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)))
15352c4ab24aSJunchao Zhang   {
15362c4ab24aSJunchao Zhang     PetscCallVoid(PetscObjectReference((PetscObject)sf));
15372c4ab24aSJunchao Zhang   }
15382c4ab24aSJunchao Zhang 
15392c4ab24aSJunchao Zhang   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
15402c4ab24aSJunchao Zhang };
15412c4ab24aSJunchao Zhang 
15422c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
15432c4ab24aSJunchao Zhang {
15442c4ab24aSJunchao Zhang   PetscFunctionBegin;
15452c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
15462c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15472c4ab24aSJunchao Zhang }
15482c4ab24aSJunchao Zhang 
1549d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1550d71ae5a4SJacob Faibussowitsch {
15512c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
15522c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJ       *coo_h;
15532c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo_d;
155442550becSJunchao Zhang 
155542550becSJunchao Zhang   PetscFunctionBegin;
155630203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1557cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15589566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15599566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15609566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
15612c4ab24aSJunchao Zhang 
15622c4ab24aSJunchao Zhang   // Copy the COO struct to device
15632c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
15642c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
15652c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
15662c4ab24aSJunchao Zhang 
15672c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
15682c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
15692c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
15702c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
15712c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
15722c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
15733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
157442550becSJunchao Zhang }
157542550becSJunchao Zhang 
1576d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1577d71ae5a4SJacob Faibussowitsch {
1578394ed5ebSJunchao Zhang   Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
157942550becSJunchao Zhang   Mat                        A = mpiaij->A, B = mpiaij->B;
158042550becSJunchao Zhang   MatScalarKokkosView        Aa, Ba;
1581394ed5ebSJunchao Zhang   MatScalarKokkosView        v1;
158242550becSJunchao Zhang   PetscMemType               memtype;
15832c4ab24aSJunchao Zhang   PetscContainer             container;
15842c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo;
158542550becSJunchao Zhang 
158642550becSJunchao Zhang   PetscFunctionBegin;
15872c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
15882c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
15892c4ab24aSJunchao Zhang 
15902c4ab24aSJunchao Zhang   const auto &n      = coo->n;
15912c4ab24aSJunchao Zhang   const auto &Annz   = coo->Annz;
15922c4ab24aSJunchao Zhang   const auto &Annz2  = coo->Annz2;
15932c4ab24aSJunchao Zhang   const auto &Bnnz   = coo->Bnnz;
15942c4ab24aSJunchao Zhang   const auto &Bnnz2  = coo->Bnnz2;
15952c4ab24aSJunchao Zhang   const auto &vsend  = coo->sendbuf;
15962c4ab24aSJunchao Zhang   const auto &v2     = coo->recvbuf;
15972c4ab24aSJunchao Zhang   const auto &Ajmap1 = coo->Ajmap1;
15982c4ab24aSJunchao Zhang   const auto &Ajmap2 = coo->Ajmap2;
15992c4ab24aSJunchao Zhang   const auto &Aimap2 = coo->Aimap2;
16002c4ab24aSJunchao Zhang   const auto &Bjmap1 = coo->Bjmap1;
16012c4ab24aSJunchao Zhang   const auto &Bjmap2 = coo->Bjmap2;
16022c4ab24aSJunchao Zhang   const auto &Bimap2 = coo->Bimap2;
16032c4ab24aSJunchao Zhang   const auto &Aperm1 = coo->Aperm1;
16042c4ab24aSJunchao Zhang   const auto &Aperm2 = coo->Aperm2;
16052c4ab24aSJunchao Zhang   const auto &Bperm1 = coo->Bperm1;
16062c4ab24aSJunchao Zhang   const auto &Bperm2 = coo->Bperm2;
16072c4ab24aSJunchao Zhang   const auto &Cperm1 = coo->Cperm1;
16082c4ab24aSJunchao Zhang 
16099566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
161042550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
16112c4ab24aSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n));
161242550becSJunchao Zhang   } else {
16132c4ab24aSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
161442550becSJunchao Zhang   }
161542550becSJunchao Zhang 
161642550becSJunchao Zhang   if (imode == INSERT_VALUES) {
16179566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
16189566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1619394ed5ebSJunchao Zhang   } else {
16209566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
16219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
162242550becSJunchao Zhang   }
162342550becSJunchao Zhang 
162408bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
162542550becSJunchao Zhang   /* Pack entries to be sent to remote */
1626d326c3f1SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
162742550becSJunchao Zhang 
162842550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
16292c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1630158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
16319371c9d4SSatish Balay   Kokkos::parallel_for(
1632d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1633158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1634158ec288SJunchao Zhang       if (i < Annz) {
1635158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1636ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1637158ec288SJunchao Zhang       } else {
1638158ec288SJunchao Zhang         i -= Annz;
1639158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1640ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1641158ec288SJunchao Zhang       }
1642158ec288SJunchao Zhang     });
16432c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
164442550becSJunchao Zhang 
1645158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16469371c9d4SSatish Balay   Kokkos::parallel_for(
1647d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1648158ec288SJunchao Zhang       if (i < Annz2) {
1649158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1650158ec288SJunchao Zhang       } else {
1651158ec288SJunchao Zhang         i -= Annz2;
1652158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1653158ec288SJunchao Zhang       }
1654158ec288SJunchao Zhang     });
165508bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
165642550becSJunchao Zhang 
1657394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
16589566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
16599566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1660394ed5ebSJunchao Zhang   } else {
16619566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
16629566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1663394ed5ebSJunchao Zhang   }
16643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
166542550becSJunchao Zhang }
166642550becSJunchao Zhang 
16672c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1668d71ae5a4SJacob Faibussowitsch {
1669076ba34aSJunchao Zhang   PetscFunctionBegin;
16709566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
16719566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
16729566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
16739566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
16749566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
16753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1676076ba34aSJunchao Zhang }
1677076ba34aSJunchao Zhang 
1678f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1679f4747e26SJunchao Zhang {
1680f4747e26SJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1681f4747e26SJunchao Zhang   PetscBool   congruent;
1682f4747e26SJunchao Zhang 
1683f4747e26SJunchao Zhang   PetscFunctionBegin;
1684f4747e26SJunchao Zhang   PetscCall(MatHasCongruentLayouts(A, &congruent));
1685f4747e26SJunchao Zhang   if (congruent) { // square matrix and the diagonals are solely in the diag block
1686f4747e26SJunchao Zhang     PetscCall(MatShift(mpiaij->A, a));
1687f4747e26SJunchao Zhang   } else { // too hard, use the general version
1688f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1689f4747e26SJunchao Zhang   }
1690f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1691f4747e26SJunchao Zhang }
1692f4747e26SJunchao Zhang 
16932c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
16942c4ab24aSJunchao Zhang {
16952c4ab24aSJunchao Zhang   PetscFunctionBegin;
16962c4ab24aSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
16972c4ab24aSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
16982c4ab24aSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
16992c4ab24aSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
17002c4ab24aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
17012c4ab24aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1702f4747e26SJunchao Zhang   B->ops->shift                 = MatShift_MPIAIJKokkos;
17032c4ab24aSJunchao Zhang 
17042c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
17052c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
17062c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
17072c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
17082c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
17092c4ab24aSJunchao Zhang }
17102c4ab24aSJunchao Zhang 
1711d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1712d71ae5a4SJacob Faibussowitsch {
17138c3ff71bSJunchao Zhang   Mat         B;
1714076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
17158c3ff71bSJunchao Zhang 
17168c3ff71bSJunchao Zhang   PetscFunctionBegin;
17178c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17189566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
17198c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17209566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
17218c3ff71bSJunchao Zhang   }
17228c3ff71bSJunchao Zhang   B = *newmat;
17238c3ff71bSJunchao Zhang 
17246f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17259566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
17269566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
17279566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
17288c3ff71bSJunchao Zhang 
1729076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
17309566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
17319566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
17329566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
17332c4ab24aSJunchao Zhang   PetscCall(MatSetOps_MPIAIJKokkos(B));
17343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17358c3ff71bSJunchao Zhang }
17362c4ab24aSJunchao Zhang 
17373f3ba80aSJunchao Zhang /*MC
173811a5261eSBarry Smith    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
17398c3ff71bSJunchao Zhang 
174015229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
17413f3ba80aSJunchao Zhang 
17422ef1f0ffSBarry Smith    Options Database Key:
17432ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
17443f3ba80aSJunchao Zhang 
17453f3ba80aSJunchao Zhang   Level: beginner
17463f3ba80aSJunchao Zhang 
17471cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
17483f3ba80aSJunchao Zhang M*/
1749d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1750d71ae5a4SJacob Faibussowitsch {
17518c3ff71bSJunchao Zhang   PetscFunctionBegin;
17529566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
17539566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
17549566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
17553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17568c3ff71bSJunchao Zhang }
17578c3ff71bSJunchao Zhang 
17588c3ff71bSJunchao Zhang /*@C
175911a5261eSBarry Smith   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
17608c3ff71bSJunchao Zhang   (the default parallel PETSc format).  This matrix will ultimately pushed down
176120f4b53cSBarry Smith   to Kokkos for calculations.
17628c3ff71bSJunchao Zhang 
17638c3ff71bSJunchao Zhang   Collective
17648c3ff71bSJunchao Zhang 
17658c3ff71bSJunchao Zhang   Input Parameters:
176611a5261eSBarry Smith + comm  - MPI communicator, set to `PETSC_COMM_SELF`
176720f4b53cSBarry Smith . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
176820f4b53cSBarry Smith            This value should be the same as the local size used in creating the
176920f4b53cSBarry Smith            y vector for the matrix-vector product y = Ax.
177020f4b53cSBarry Smith . n     - This value should be the same as the local size used in creating the
177120f4b53cSBarry Smith        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
177220f4b53cSBarry Smith        calculated if N is given) For square matrices n is almost always `m`.
177320f4b53cSBarry Smith . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
177420f4b53cSBarry Smith . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
177520f4b53cSBarry Smith . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
177620f4b53cSBarry Smith            (same value is used for all local rows)
177720f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the
177820f4b53cSBarry Smith            DIAGONAL portion of the local submatrix (possibly different for each row)
177920f4b53cSBarry Smith            or `NULL`, if `d_nz` is used to specify the nonzero structure.
178020f4b53cSBarry Smith            The size of this array is equal to the number of local rows, i.e `m`.
178120f4b53cSBarry Smith            For matrices you plan to factor you must leave room for the diagonal entry and
178220f4b53cSBarry Smith            put in the entry even if it is zero.
178320f4b53cSBarry Smith . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
178420f4b53cSBarry Smith            submatrix (same value is used for all local rows).
178520f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the
178620f4b53cSBarry Smith            OFF-DIAGONAL portion of the local submatrix (possibly different for
178720f4b53cSBarry Smith            each row) or `NULL`, if `o_nz` is used to specify the nonzero
178820f4b53cSBarry Smith            structure. The size of this array is equal to the number
178920f4b53cSBarry Smith            of local rows, i.e `m`.
17908c3ff71bSJunchao Zhang 
17918c3ff71bSJunchao Zhang   Output Parameter:
17928c3ff71bSJunchao Zhang . A - the matrix
17938c3ff71bSJunchao Zhang 
17942ef1f0ffSBarry Smith   Level: intermediate
17952ef1f0ffSBarry Smith 
17962ef1f0ffSBarry Smith   Notes:
179711a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
17988c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
179911a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
18008c3ff71bSJunchao Zhang 
1801667f096bSBarry Smith   The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
18028c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
18032ef1f0ffSBarry Smith   either one (as in Fortran) or zero.
18048c3ff71bSJunchao Zhang 
18051cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1806fe59aa6dSJacob Faibussowitsch           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
18078c3ff71bSJunchao Zhang @*/
1808d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1809d71ae5a4SJacob Faibussowitsch {
18108c3ff71bSJunchao Zhang   PetscMPIInt size;
18118c3ff71bSJunchao Zhang 
18128c3ff71bSJunchao Zhang   PetscFunctionBegin;
18139566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
18149566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
18159566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
18168c3ff71bSJunchao Zhang   if (size > 1) {
18179566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
18189566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
18198c3ff71bSJunchao Zhang   } else {
18209566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
18219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
18228c3ff71bSJunchao Zhang   }
18233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18248c3ff71bSJunchao Zhang }
1825