xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 39cfb5082524f0710c9e9368670444fb0d934d7e)
1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
592896123SJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
62c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
78c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
8076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
90e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
1011d22bbfSJunchao Zhang 
1166976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12d71ae5a4SJacob Faibussowitsch {
1330203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
148c3ff71bSJunchao Zhang 
158c3ff71bSJunchao Zhang   PetscFunctionBegin;
169566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1730203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1830203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1930203840SJunchao Zhang    */
2030203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2192896123SJunchao Zhang     PetscScalarKokkosView v;
2292896123SJunchao Zhang 
2330203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2430203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2592896123SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
2692896123SJunchao Zhang     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
2792896123SJunchao Zhang     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
2830203840SJunchao Zhang   }
293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
308c3ff71bSJunchao Zhang }
318c3ff71bSJunchao Zhang 
3266976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33d71ae5a4SJacob Faibussowitsch {
342cdb1aeaSJunchao Zhang   Mat_MPIAIJ *mpiaij;
358c3ff71bSJunchao Zhang 
368c3ff71bSJunchao Zhang   PetscFunctionBegin;
372cdb1aeaSJunchao Zhang   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
382cdb1aeaSJunchao Zhang   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
392cdb1aeaSJunchao Zhang   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
402cdb1aeaSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
412cdb1aeaSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
438c3ff71bSJunchao Zhang }
448c3ff71bSJunchao Zhang 
4566976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
46d71ae5a4SJacob Faibussowitsch {
478c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
488c3ff71bSJunchao Zhang   PetscInt    nt;
498c3ff71bSJunchao Zhang 
508c3ff71bSJunchao Zhang   PetscFunctionBegin;
519566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
5208401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
539566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
549566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
559566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
569566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
573ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
588c3ff71bSJunchao Zhang }
598c3ff71bSJunchao Zhang 
6066976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
61d71ae5a4SJacob Faibussowitsch {
628c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
638c3ff71bSJunchao Zhang   PetscInt    nt;
648c3ff71bSJunchao Zhang 
658c3ff71bSJunchao Zhang   PetscFunctionBegin;
669566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
6708401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
689566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
699566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
709566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
719566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
738c3ff71bSJunchao Zhang }
748c3ff71bSJunchao Zhang 
7566976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
76d71ae5a4SJacob Faibussowitsch {
778c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
788c3ff71bSJunchao Zhang   PetscInt    nt;
798c3ff71bSJunchao Zhang 
808c3ff71bSJunchao Zhang   PetscFunctionBegin;
819566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8208401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
839566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
849566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
859566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
869566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
888c3ff71bSJunchao Zhang }
898c3ff71bSJunchao Zhang 
90076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
91076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
92076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
93076ba34aSJunchao Zhang */
9466976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
95d71ae5a4SJacob Faibussowitsch {
96076ba34aSJunchao Zhang   Mat             Ad, Ao;
97076ba34aSJunchao Zhang   const PetscInt *cmap;
98076ba34aSJunchao Zhang 
99076ba34aSJunchao Zhang   PetscFunctionBegin;
1009566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1019566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
102076ba34aSJunchao Zhang   if (glob) {
103076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1049566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1059566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1069566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1079566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
108076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
109076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1109566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
111076ba34aSJunchao Zhang   }
1123ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
113076ba34aSJunchao Zhang }
114076ba34aSJunchao Zhang 
1150e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
116076ba34aSJunchao Zhang struct MatMatStruct {
1170e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1180e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1190e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1200e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1210e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1220e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1230e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1240e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1250e3ece09SJunchao Zhang 
1260e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1270e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1280e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1290e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1300e3ece09SJunchao Zhang 
131aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1320e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1330e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1340e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1350e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1360e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
137076ba34aSJunchao Zhang 
138d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
139d71ae5a4SJacob Faibussowitsch   {
1403ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1413ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1423ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
143076ba34aSJunchao Zhang   }
144076ba34aSJunchao Zhang };
145076ba34aSJunchao Zhang 
146076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1470e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1480e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1490e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
150076ba34aSJunchao Zhang };
151076ba34aSJunchao Zhang 
152076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1530e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1540e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1550e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1560e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
157076ba34aSJunchao Zhang };
158076ba34aSJunchao Zhang 
1599371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1603ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1613ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1623ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
1630e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
164076ba34aSJunchao Zhang 
165d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
166d71ae5a4SJacob Faibussowitsch   {
167076ba34aSJunchao Zhang     delete mmAB;
168076ba34aSJunchao Zhang     delete mmAtB;
1690e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
170076ba34aSJunchao Zhang   }
171076ba34aSJunchao Zhang };
172076ba34aSJunchao Zhang 
173d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
174d71ae5a4SJacob Faibussowitsch {
175076ba34aSJunchao Zhang   PetscFunctionBegin;
1769566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
1773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
178076ba34aSJunchao Zhang }
179076ba34aSJunchao Zhang 
180076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
181076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
182076ba34aSJunchao Zhang 
183076ba34aSJunchao Zhang   Input Parameters:
184076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
185076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
186076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
187076ba34aSJunchao Zhang 
1882fe279fdSBarry Smith   Output Parameter:
189076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
190076ba34aSJunchao Zhang */
1910e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
192d71ae5a4SJacob Faibussowitsch {
193076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
194076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
195076ba34aSJunchao Zhang 
196076ba34aSJunchao Zhang   PetscFunctionBegin;
1979566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
1989566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
1999566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2009566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
201076ba34aSJunchao Zhang 
202aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
20308401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
2040e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
20508401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
206076ba34aSJunchao Zhang   mpiaij->A      = A;
207076ba34aSJunchao Zhang   mpiaij->B      = B;
2080e3ece09SJunchao Zhang   mpiaij->garray = garray;
209076ba34aSJunchao Zhang 
210076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
211076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
212076ba34aSJunchao Zhang 
2139566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2149566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
215076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
216076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
217076ba34aSJunchao Zhang   */
2189566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2199566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2209566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
222076ba34aSJunchao Zhang }
223076ba34aSJunchao Zhang 
2240e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
2250e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
2260e3ece09SJunchao Zhang template <class ExecutionSpace>
2270e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
228d71ae5a4SJacob Faibussowitsch {
2291aa660a0SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
2301aa660a0SJunchao Zhang   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
2311aa660a0SJunchao Zhang #else
2321aa660a0SJunchao Zhang   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
2331aa660a0SJunchao Zhang #endif
2340e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
235076ba34aSJunchao Zhang 
236076ba34aSJunchao Zhang   PetscFunctionBegin;
2370e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
238076ba34aSJunchao Zhang 
2390e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
240076ba34aSJunchao Zhang 
2410e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
242076ba34aSJunchao Zhang 
2430e3ece09SJunchao Zhang   if (vector_length < 1) {
2440e3ece09SJunchao Zhang     vector_length = 1;
2450e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
246076ba34aSJunchao Zhang   }
247076ba34aSJunchao Zhang 
2480e3ece09SJunchao Zhang   // Determine rows per thread
2490e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2501aa660a0SJunchao Zhang     if (is_gpu_exec_space) rows_per_thread = 1;
2510e3ece09SJunchao Zhang     else {
2520e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2530e3ece09SJunchao Zhang         rows_per_thread = 256;
2540e3ece09SJunchao Zhang       } else rows_per_thread = 64;
255076ba34aSJunchao Zhang     }
256076ba34aSJunchao Zhang   }
257076ba34aSJunchao Zhang 
2580e3ece09SJunchao Zhang   if (team_size < 1) {
2591aa660a0SJunchao Zhang     if (is_gpu_exec_space) {
2600e3ece09SJunchao Zhang       team_size = 256 / vector_length;
261076ba34aSJunchao Zhang     } else {
2620e3ece09SJunchao Zhang       team_size = 1;
2630e3ece09SJunchao Zhang     }
264076ba34aSJunchao Zhang   }
265076ba34aSJunchao Zhang 
2660e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
267076ba34aSJunchao Zhang 
2680e3ece09SJunchao Zhang   if (rows_per_team < 0) {
2690e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
2700e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
2710e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
2720e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
2730e3ece09SJunchao Zhang   }
2743ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
275076ba34aSJunchao Zhang }
276076ba34aSJunchao Zhang 
2770e3ece09SJunchao Zhang /*
2780e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
279076ba34aSJunchao Zhang 
280076ba34aSJunchao Zhang   Input Parameters:
2810e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
2820e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
2830e3ece09SJunchao Zhang .  m           - size of indices[], the second set
2840e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
285076ba34aSJunchao Zhang 
286076ba34aSJunchao Zhang   Output Parameters:
2870e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
2880e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
2890e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
2900e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
291076ba34aSJunchao Zhang 
2920e3ece09SJunchao Zhang    Example, say
2930e3ece09SJunchao Zhang     n1         = 5
2940e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
2950e3ece09SJunchao Zhang     m          = 4
2960e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
29711a5261eSBarry Smith 
2980e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
2990e3ece09SJunchao Zhang     n2         = 7
3000e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
3010e3ece09SJunchao Zhang 
3020e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
3030e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
3040e3ece09SJunchao Zhang 
3050e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
3060e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
307076ba34aSJunchao Zhang */
3080e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
309d71ae5a4SJacob Faibussowitsch {
3100e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
3110e3ece09SJunchao Zhang   PetscHashIter iter;
3120e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
3130e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
314076ba34aSJunchao Zhang 
315076ba34aSJunchao Zhang   PetscFunctionBegin;
3160e3ece09SJunchao Zhang   tot = 0;
3170e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
3180e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
3190e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
3200e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
321076ba34aSJunchao Zhang   }
322076ba34aSJunchao Zhang 
3230e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
3240e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
3250e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
326076ba34aSJunchao Zhang   }
327076ba34aSJunchao Zhang 
3280e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
3290e3ece09SJunchao Zhang   n2 = tot;
3300e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
3310e3ece09SJunchao Zhang   tot = 0;
3320e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
3330e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
3340e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
3350e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
3360e3ece09SJunchao Zhang     garray2[tot++] = key;
337076ba34aSJunchao Zhang   }
338076ba34aSJunchao Zhang 
3390e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
3400e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
3410e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
3420e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
343f0e6e2d1SJunchao Zhang 
3440e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
345f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3460e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3470e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3480e3ece09SJunchao Zhang     indices[i] = val;
3490e3ece09SJunchao Zhang   }
3500e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3510e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3520e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3530e3ece09SJunchao Zhang   *n2_      = n2;
3540e3ece09SJunchao Zhang   *garray2_ = garray2;
3550e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3560e3ece09SJunchao Zhang }
357f0e6e2d1SJunchao Zhang 
3580e3ece09SJunchao Zhang /*
3590e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3600e3ece09SJunchao Zhang 
3610e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
3620e3ece09SJunchao Zhang 
  Think of each row of E as a leaf; the given ownerSF then specifies the roots for these leaves. A root may connect to multiple leaves.
3640e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3650e3ece09SJunchao Zhang 
3660e3ece09SJunchao Zhang   Input Parameters:
3670e3ece09SJunchao Zhang +  comm       - MPI communicator of E
3680e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
3690e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
3700e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
3710e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
3720e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
3730e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
3740e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
3750e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
3760e3ece09SJunchao Zhang 
3770e3ece09SJunchao Zhang   Output Parameters:
3780e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
3790e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
3800e3ece09SJunchao Zhang 
3810e3ece09SJunchao Zhang   Notes:
3820e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
3830e3ece09SJunchao Zhang 
3840e3ece09SJunchao Zhang  */
3850e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
3860e3ece09SJunchao Zhang {
3870e3ece09SJunchao Zhang   PetscFunctionBegin;
3880e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
3890e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
3900e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
3910e3ece09SJunchao Zhang 
3920e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
3930e3ece09SJunchao Zhang 
3940e3ece09SJunchao Zhang     // Do the analysis on host
39545402d8aSJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map);
39645402d8aSJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries);
39745402d8aSJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map);
39845402d8aSJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries);
3990e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
4000e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
4010e3ece09SJunchao Zhang 
4020e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
4037b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
4040e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
4050e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
4060e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
4070e3ece09SJunchao Zhang       PetscInt        count, step;
4080e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
4090e3ece09SJunchao Zhang       first = Bj + Bi[i];
4100e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
411f0e6e2d1SJunchao Zhang       count = last - first;
412f0e6e2d1SJunchao Zhang       while (count > 0) {
413f0e6e2d1SJunchao Zhang         it   = first;
414f0e6e2d1SJunchao Zhang         step = count / 2;
415f0e6e2d1SJunchao Zhang         it += step;
4160e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
417f0e6e2d1SJunchao Zhang           first = ++it;
418f0e6e2d1SJunchao Zhang           count -= step + 1;
419f0e6e2d1SJunchao Zhang         } else count = step;
420f0e6e2d1SJunchao Zhang       }
4210e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
4220e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
423f0e6e2d1SJunchao Zhang     }
424f0e6e2d1SJunchao Zhang 
4250e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
4260e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
4270e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
428c09cee04SJames Wright     PetscMPIInt        niranks, nranks;
4290e3ece09SJunchao Zhang     MPI_Request       *reqs;
4300e3ece09SJunchao Zhang     PetscMPIInt        tag;
4310e3ece09SJunchao Zhang     PetscSF            reduceSF;
4320e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
433f0e6e2d1SJunchao Zhang 
4340e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
4350e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
4360e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
437f0e6e2d1SJunchao Zhang 
4380e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
4390e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
4400e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
4410e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
442f0e6e2d1SJunchao Zhang 
4430e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4440e3ece09SJunchao Zhang 
4450e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4460e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4470e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4486497c311SBarry Smith     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
4496497c311SBarry Smith     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
4500e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4510e3ece09SJunchao Zhang 
4520e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4530e3ece09SJunchao Zhang     rdisp[0] = 0;
4540e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4550e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4560e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4570e3ece09SJunchao Zhang     }
4580e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4590e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4600e3ece09SJunchao Zhang 
4616497c311SBarry Smith     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
4626497c311SBarry Smith     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
4630e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4640e3ece09SJunchao Zhang 
4650e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
4660e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
4670e3ece09SJunchao Zhang     PetscSFNode *iremote;
4680e3ece09SJunchao Zhang 
4690e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
4700e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
4710e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
4720e3ece09SJunchao Zhang 
4730e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
4740e3ece09SJunchao Zhang       PetscInt count = 0;
4750e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
4760e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
4770e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
4780e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
4790e3ece09SJunchao Zhang       }
4800e3ece09SJunchao Zhang       nleaves += count;
4810e3ece09SJunchao Zhang     }
4820e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
4830e3ece09SJunchao Zhang 
4840e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
4850e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
4860e3ece09SJunchao Zhang 
4870e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
4880e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
4890e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
4900e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
4910e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
4920e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
4930e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
4940e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
4950e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
4960e3ece09SJunchao Zhang         if (j < nzLeft) {
4970e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
4980e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
4990e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
5000e3ece09SJunchao Zhang         } else {
5010e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
5020e3ece09SJunchao Zhang         }
5030e3ece09SJunchao Zhang       }
5040e3ece09SJunchao Zhang     }
5050e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
5060e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
5070e3ece09SJunchao Zhang 
5080e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
5090e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
5100e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
5110e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
5120e3ece09SJunchao Zhang 
5130e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
5140e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
5150e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
5160e3ece09SJunchao Zhang 
5170e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
5187b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
5190e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
5200e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
5210e3ece09SJunchao Zhang     PetscInt                iter;
5220e3ece09SJunchao Zhang 
5230e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
5240e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
5250e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
5260e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
5270e3ece09SJunchao Zhang     iter = 0;
5280e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
5290e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
5300e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
5310e3ece09SJunchao Zhang 
5320e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5330e3ece09SJunchao Zhang 
5340e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
5350e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
5360e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
5370e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
5380e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
5390e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
5400e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5410e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
5420e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5430e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5440e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5450e3ece09SJunchao Zhang         jbuf2 += len;
5460e3ece09SJunchao Zhang         pbuf2 += len;
5470e3ece09SJunchao Zhang         nz += len;
5480e3ece09SJunchao Zhang       }
5490e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5500e3ece09SJunchao Zhang 
5510e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5520e3ece09SJunchao Zhang       PetscInt cur = 0;
5530e3ece09SJunchao Zhang       while (cur < nz) {
5540e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5550e3ece09SJunchao Zhang         PetscInt dups      = 1;
5560e3ece09SJunchao Zhang 
5570e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5580e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5590e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5600e3ece09SJunchao Zhang           FdnzDups += dups;
5610e3ece09SJunchao Zhang         } else {
5620e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5630e3ece09SJunchao Zhang           FonzDups += dups;
5640e3ece09SJunchao Zhang         }
5650e3ece09SJunchao Zhang         cur += dups;
5660e3ece09SJunchao Zhang       }
5670e3ece09SJunchao Zhang 
5680e3ece09SJunchao Zhang       FnzDups += nz;
5690e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5700e3ece09SJunchao Zhang     }
5710e3ece09SJunchao Zhang 
5720e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
5730e3ece09SJunchao Zhang     Foi = Foi_h.data();
5740e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
5750e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
5760e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
5770e3ece09SJunchao Zhang     }
5780e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
5790e3ece09SJunchao Zhang     Fonz = Foi[Fm];
5800e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
5810e3ece09SJunchao Zhang 
5820e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
5837b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
5847b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
5857b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
5860e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
5870e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
5880e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
5890e3ece09SJunchao Zhang 
5900e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
5910e3ece09SJunchao Zhang     Fdjmap[0] = 0;
5920e3ece09SJunchao Zhang     Fojmap[0] = 0;
5930e3ece09SJunchao Zhang     FnzDups   = 0;
5940e3ece09SJunchao Zhang     Fdnz      = 0;
5950e3ece09SJunchao Zhang     Fonz      = 0;
5960e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
5970e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
5980e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
5990e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
6000e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
6010e3ece09SJunchao Zhang 
6020e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
6030e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
6040e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
6050e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
6060e3ece09SJunchao Zhang       }
6070e3ece09SJunchao Zhang 
6080e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
6090e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
6100e3ece09SJunchao Zhang       PetscInt cur = 0;
6110e3ece09SJunchao Zhang       while (cur < nz) {
6120e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
6130e3ece09SJunchao Zhang         PetscInt dups      = 1;
6140e3ece09SJunchao Zhang 
6150e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
6160e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
6170e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
6180e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
6190e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
6200e3ece09SJunchao Zhang           FdnzDups += dups;
6210e3ece09SJunchao Zhang           Fdnz++;
6220e3ece09SJunchao Zhang         } else {
6230e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
6240e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
6250e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
6260e3ece09SJunchao Zhang           FonzDups += dups;
6270e3ece09SJunchao Zhang           Fonz++;
6280e3ece09SJunchao Zhang         }
6290e3ece09SJunchao Zhang         cur += dups;
6300e3ece09SJunchao Zhang         FnzDups += dups;
6310e3ece09SJunchao Zhang       }
6320e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6330e3ece09SJunchao Zhang     }
6340e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
6350e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
6360e3ece09SJunchao Zhang 
6370e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
6380e3ece09SJunchao Zhang     PetscInt n2, *garray2;
6390e3ece09SJunchao Zhang 
6400e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
6410e3ece09SJunchao Zhang     mm->sf       = reduceSF;
6427b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
6437b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
644aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6450e3ece09SJunchao Zhang     mm->n        = n2;
6460e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6470e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6480e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6490e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6500e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6510e3ece09SJunchao Zhang 
6520e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
6537b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6540e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6550e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
6567b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6570e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6580e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6590e3ece09SJunchao Zhang 
6600e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6610e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6620e3ece09SJunchao Zhang 
6630e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6640e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6650e3ece09SJunchao Zhang 
6660e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
6670e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
6680e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
6690e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
6700e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
6710e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
6720e3ece09SJunchao Zhang 
6730e3ece09SJunchao Zhang   // Handy aliases
6740e3ece09SJunchao Zhang   auto       &Aa           = A.values;
6750e3ece09SJunchao Zhang   auto       &Ba           = B.values;
6760e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
6770e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
6780e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
6790e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
6800e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
6810e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
6820e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
6830e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
6840e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
6850e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
6860e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
6870e3ece09SJunchao Zhang 
6880e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
6890e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
690d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
6910e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
6920e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
6930e3ece09SJunchao Zhang         if (i < Em) {
6940e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
6950e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
6960e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
6970e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
6980e3ece09SJunchao Zhang 
6990e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
7000e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
7010e3ece09SJunchao Zhang             if (j < nzleft) { // B left
7020e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
7030e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
7040e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
7050e3ece09SJunchao Zhang             } else { // B right
7060e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
707f0e6e2d1SJunchao Zhang             }
708f0e6e2d1SJunchao Zhang           });
709f0e6e2d1SJunchao Zhang         }
710f0e6e2d1SJunchao Zhang       });
7110e3ece09SJunchao Zhang     }));
7120e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
713f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
714f0e6e2d1SJunchao Zhang }
7150e3ece09SJunchao Zhang 
716aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
7170e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
7180e3ece09SJunchao Zhang {
7190e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
7200e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
7210e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
7220e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
7230e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
7240e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
7250e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
7260e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
7270e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
7280e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
7290e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
7300e3ece09SJunchao Zhang 
731d326c3f1SJunchao Zhang   PetscFunctionBegin;
7320e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
7330e3ece09SJunchao Zhang 
7340e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
7350e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
736d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
7370e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7380e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
7390e3ece09SJunchao Zhang       Fda(i) = sum;
7400e3ece09SJunchao Zhang     }));
7410e3ece09SJunchao Zhang 
7420e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
743d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
7440e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7450e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7460e3ece09SJunchao Zhang       Foa(i) = sum;
7470e3ece09SJunchao Zhang     }));
7480e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7490e3ece09SJunchao Zhang }
7500e3ece09SJunchao Zhang 
7510e3ece09SJunchao Zhang /*
7520e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7530e3ece09SJunchao Zhang 
7540e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
7550e3ece09SJunchao Zhang   device and involves various index mapping.
7560e3ece09SJunchao Zhang 
7570e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
7580e3ece09SJunchao Zhang   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
7590e3ece09SJunchao Zhang   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7600e3ece09SJunchao Zhang   F has the same column layout as E.
7610e3ece09SJunchao Zhang 
  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
  Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
  Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
  column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
  column indices in Fo and update Fo with local indices.
7670e3ece09SJunchao Zhang 
7680e3ece09SJunchao Zhang    Input Parameters:
7690e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
7709c89aa79SPierre Jolivet .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
7710e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
7720e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
7730e3ece09SJunchao Zhang 
7740e3ece09SJunchao Zhang     Output Parameters:
7750e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
7760e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
7770e3ece09SJunchao Zhang 
7780e3ece09SJunchao Zhang     Notes:
    When reuse = MAT_REUSE_MATRIX, ownerSF and map are not significant.
    The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide opportunities to overlap computation with communication.
7810e3ece09SJunchao Zhang */
7820e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
7830e3ece09SJunchao Zhang {
7840e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
7850e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
7860e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
7870e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
7880e3ece09SJunchao Zhang   MPI_Comm          comm;
7890e3ece09SJunchao Zhang 
7900e3ece09SJunchao Zhang   PetscFunctionBegin;
7910e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
7920e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
7930e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
7940e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
7950e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
7960e3ece09SJunchao Zhang     PetscInt        cstart, cend;
7970e3ece09SJunchao Zhang     PetscSF         bcastSF;
7980e3ece09SJunchao Zhang 
7990e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
8000e3ece09SJunchao Zhang 
8010e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
8027b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
8030e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
8040e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
8050e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
8060e3ece09SJunchao Zhang       PetscInt        count, step;
8070e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
8080e3ece09SJunchao Zhang       first = Bj + Bi[i];
8090e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
8100e3ece09SJunchao Zhang       count = last - first;
8110e3ece09SJunchao Zhang       while (count > 0) {
8120e3ece09SJunchao Zhang         it   = first;
8130e3ece09SJunchao Zhang         step = count / 2;
8140e3ece09SJunchao Zhang         it += step;
8150e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
8160e3ece09SJunchao Zhang           first = ++it;
8170e3ece09SJunchao Zhang           count -= step + 1;
8180e3ece09SJunchao Zhang         } else count = step;
8190e3ece09SJunchao Zhang       }
8200e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
8210e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
8220e3ece09SJunchao Zhang     }
8230e3ece09SJunchao Zhang 
8240e3ece09SJunchao Zhang     // Compute row pointer Fi of F
8250e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
8260e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
8270e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
8280e3ece09SJunchao Zhang     Fi[0] = 0;
8290e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
8300e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
8310e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
8320e3ece09SJunchao Zhang     Fnz = Fi[Fm];
8330e3ece09SJunchao Zhang 
8340e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
8350e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
8360e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
837c09cee04SJames Wright     PetscMPIInt        niranks, nranks;
838c09cee04SJames Wright     PetscInt          *sdisp, *rdisp;
8390e3ece09SJunchao Zhang     MPI_Request       *reqs;
8400e3ece09SJunchao Zhang     PetscMPIInt        tag;
8410e3ece09SJunchao Zhang 
8420e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8430e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8440e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8450e3ece09SJunchao Zhang 
8460e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8470e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8480e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8490e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8500e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8510e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8520e3ece09SJunchao Zhang       }
8530e3ece09SJunchao Zhang     }
8540e3ece09SJunchao Zhang 
8550e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8566497c311SBarry Smith     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8576497c311SBarry Smith     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8580e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8590e3ece09SJunchao Zhang 
8600e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8610e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8620e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8630e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8640e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8650e3ece09SJunchao Zhang       PetscInt k = 0;
8660e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
8670e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
8680e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
8690e3ece09SJunchao Zhang         k++;
8700e3ece09SJunchao Zhang       }
8710e3ece09SJunchao Zhang     }
8720e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
8730e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
8740e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
8750e3ece09SJunchao Zhang 
8760e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
8777b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
8780e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
8790e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
8807b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
8810e3ece09SJunchao Zhang 
8820e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
8830e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
8840e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
8850e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
8860e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
8870e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
8880e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
8890e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
8900e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
8910e3ece09SJunchao Zhang         if (j < nzLeft) {
8920e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
8930e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
8940e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
8950e3ece09SJunchao Zhang         } else {
8960e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
8970e3ece09SJunchao Zhang         }
8980e3ece09SJunchao Zhang       }
8990e3ece09SJunchao Zhang     }
9000e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
9010e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
9020e3ece09SJunchao Zhang 
9030e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
9047b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
9057b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
9060e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
9070e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
9080e3ece09SJunchao Zhang 
9090e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
9100e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9110e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
9120e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
9130e3ece09SJunchao Zhang       first       = Fj + Fi[i];
9140e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
9150e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
9160e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
9170e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
9180e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
9190e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
9200e3ece09SJunchao Zhang     }
9210e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9220e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
9230e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
9240e3ece09SJunchao Zhang     }
9250e3ece09SJunchao Zhang 
9260e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
9270e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
9287b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
9290e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
9300e3ece09SJunchao Zhang 
9310e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9320e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
9330e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
9340e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
9350e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
9360e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
9370e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
9380e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
9390e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
9400e3ece09SJunchao Zhang         } else { // right, in global
9410e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
9420e3ece09SJunchao Zhang         }
9430e3ece09SJunchao Zhang       }
9440e3ece09SJunchao Zhang     }
9450e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9460e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9470e3ece09SJunchao Zhang 
9480e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9490e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9500e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9510e3ece09SJunchao Zhang 
9520e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9530e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
9547b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9550e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9560e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9570e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9580e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9590e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9600e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
9617b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
9627b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9630e3ece09SJunchao Zhang     mm->garray    = garray2;
9640e3ece09SJunchao Zhang     mm->n         = n2;
9650e3ece09SJunchao Zhang 
9660e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
9677b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
9680e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
9690e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
9700e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
9710e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
9720e3ece09SJunchao Zhang 
9730e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
9740e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
9750e3ece09SJunchao Zhang 
9760e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
9770e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
9780e3ece09SJunchao Zhang 
9790e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
9800e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
9810e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
9820e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
9830e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
9840e3ece09SJunchao Zhang 
9850e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
9860e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
9870e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
9880e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
9890e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
9900e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
9910e3ece09SJunchao Zhang 
9920e3ece09SJunchao Zhang   // Sync E's value to device
9930e3ece09SJunchao Zhang   akok->a_dual.sync_device();
9940e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
9950e3ece09SJunchao Zhang 
9960e3ece09SJunchao Zhang   // Handy aliases
9970e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
9980e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
9990e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
10000e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
10010e3ece09SJunchao Zhang 
10020e3ece09SJunchao Zhang   // Fetch the plans
10030e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
10040e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
10050e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
10060e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
10070e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
10080e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
10090e3ece09SJunchao Zhang 
10100e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
10110e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
10120e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
10130e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
10140e3ece09SJunchao Zhang 
10150e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
10160e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1017d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10180e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10190e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
10200e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
10210e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
10220e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
10230e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
10240e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
10250e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
10260e3ece09SJunchao Zhang 
10270e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10280e3ece09SJunchao Zhang             if (j < nzleft) { // B left
10290e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
10300e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
10310e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
10320e3ece09SJunchao Zhang             } else { // B right
10330e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
10340e3ece09SJunchao Zhang             }
10350e3ece09SJunchao Zhang           });
10360e3ece09SJunchao Zhang         }
10370e3ece09SJunchao Zhang       });
10380e3ece09SJunchao Zhang     }));
10390e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
10400e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10410e3ece09SJunchao Zhang }
10420e3ece09SJunchao Zhang 
10430e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10440e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10450e3ece09SJunchao Zhang {
10460e3ece09SJunchao Zhang   PetscFunctionBegin;
10470e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10480e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10490e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10500e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10510e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10520e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10530e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10540e3ece09SJunchao Zhang 
10550e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10560e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10570e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10580e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10590e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10600e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10610e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10620e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10630e3ece09SJunchao Zhang 
10640e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10650e3ece09SJunchao Zhang 
10660e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
10670e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1068d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10690e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10700e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
10710e3ece09SJunchao Zhang         if (i < Fm) {
10720e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
10730e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
10740e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
10750e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
10760e3ece09SJunchao Zhang 
10770e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10780e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
10790e3ece09SJunchao Zhang             if (j < nzLeft) { // left
10800e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
10810e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
10820e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
10830e3ece09SJunchao Zhang             } else { // right
10840e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
10850e3ece09SJunchao Zhang             }
10860e3ece09SJunchao Zhang           });
10870e3ece09SJunchao Zhang         }
10880e3ece09SJunchao Zhang       });
10890e3ece09SJunchao Zhang     }));
10900e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10910e3ece09SJunchao Zhang }
10920e3ece09SJunchao Zhang 
10930e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
10940e3ece09SJunchao Zhang {
10950e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
10960e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
10970e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
10980e3ece09SJunchao Zhang   PetscInt        cstart, cend;
10990e3ece09SJunchao Zhang   MPI_Comm        comm;
11000e3ece09SJunchao Zhang 
11010e3ece09SJunchao Zhang   PetscFunctionBegin;
11020e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11030e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11040e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11050e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
11060e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
11070e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11080e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11090e3ece09SJunchao Zhang 
11100e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11110e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11120e3ece09SJunchao Zhang 
11130e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11140e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11150e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11160e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1117f0e6e2d1SJunchao Zhang   #endif
11180e3ece09SJunchao Zhang #endif
11190e3ece09SJunchao Zhang 
11200e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
11210e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
11220e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
11230e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
11240e3ece09SJunchao Zhang 
11250e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11260e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
11270e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
11280e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
11290e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
11300e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11310e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11320e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1133d326c3f1SJunchao Zhang 
11340e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
11350e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
11360e3ece09SJunchao Zhang #endif
11370e3ece09SJunchao Zhang 
11380e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11397b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11400e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
11410e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11420e3ece09SJunchao Zhang 
11430e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11440e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11450e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11460e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11470e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11480e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11490e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11500e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11510e3ece09SJunchao Zhang #endif
11520e3ece09SJunchao Zhang 
11530e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11540e3ece09SJunchao Zhang 
11550e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
11567b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11570e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1158d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11590e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11600e3ece09SJunchao Zhang 
11610e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11620e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11630e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11640e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11650e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11660e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11680e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11690e3ece09SJunchao Zhang }
11700e3ece09SJunchao Zhang 
11710e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11720e3ece09SJunchao Zhang {
11730e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11740e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11750e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
11760e3ece09SJunchao Zhang   MPI_Comm        comm;
11770e3ece09SJunchao Zhang 
11780e3ece09SJunchao Zhang   PetscFunctionBegin;
11790e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11800e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11810e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11820e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11830e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11840e3ece09SJunchao Zhang 
11850e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11860e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11870e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11880e3ece09SJunchao Zhang 
11890e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11900e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
11910e3ece09SJunchao Zhang 
11920e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11930e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11940e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11950e3ece09SJunchao Zhang 
11960e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
11970e3ece09SJunchao Zhang 
11980e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11990e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
12000e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
12010e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12020e3ece09SJunchao Zhang }
1203f0e6e2d1SJunchao Zhang 
1204076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1205076ba34aSJunchao Zhang 
1206076ba34aSJunchao Zhang   Input Parameters:
1207076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1208076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1209076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1210076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1211076ba34aSJunchao Zhang */
1212d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1213d71ae5a4SJacob Faibussowitsch {
12140e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12150e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12160e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1217076ba34aSJunchao Zhang 
1218076ba34aSJunchao Zhang   PetscFunctionBegin;
12190e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12200e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12210e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12220e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12230e3ece09SJunchao Zhang 
12240e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
12250e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
12260e3ece09SJunchao Zhang 
12270e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
12280e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
12290e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
12300e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
12310e3ece09SJunchao Zhang   #endif
1232f0e6e2d1SJunchao Zhang #endif
1233f0e6e2d1SJunchao Zhang 
12340e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
12350e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
12360e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
12370e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1238076ba34aSJunchao Zhang 
12390e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12407b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
12410e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1242076ba34aSJunchao Zhang 
12430e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12440e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12450e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12460e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12470e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12480e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12490e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12500e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12510e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12520e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12530e3ece09SJunchao Zhang #endif
1254076ba34aSJunchao Zhang 
12550e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1256076ba34aSJunchao Zhang 
12570e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12580e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12590e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12600e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12610e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12620e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12630e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12640e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12650e3ece09SJunchao Zhang #endif
1266076ba34aSJunchao Zhang 
12670e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
12687b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
12690e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1270d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
12710e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
12720e3ece09SJunchao Zhang 
12730e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12740e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
12750e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
12760e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
12770e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
12780e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
12790e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
12803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1281076ba34aSJunchao Zhang }
1282076ba34aSJunchao Zhang 
12830e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1284d71ae5a4SJacob Faibussowitsch {
12850e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12860e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12870e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1288076ba34aSJunchao Zhang 
1289076ba34aSJunchao Zhang   PetscFunctionBegin;
12900e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12910e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12920e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12930e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1294076ba34aSJunchao Zhang 
12950e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12960e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1297076ba34aSJunchao Zhang 
12980e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12990e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
13000e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1301076ba34aSJunchao Zhang 
13020e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1303076ba34aSJunchao Zhang 
13040e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
13050e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
13060e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
13070e3ece09SJunchao Zhang 
13080e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13090e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13100e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1312076ba34aSJunchao Zhang }
1313076ba34aSJunchao Zhang 
131466976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1315d71ae5a4SJacob Faibussowitsch {
13160e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
13170e3ece09SJunchao Zhang   Mat_Product                 *product;
13180e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1319076ba34aSJunchao Zhang   MatProductType               ptype;
13200e3ece09SJunchao Zhang   Mat                          A, B;
1321076ba34aSJunchao Zhang 
1322076ba34aSJunchao Zhang   PetscFunctionBegin;
13230e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
13240e3ece09SJunchao Zhang   product = C->product;
13250e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1326076ba34aSJunchao Zhang   ptype   = product->type;
1327076ba34aSJunchao Zhang   A       = product->A;
1328076ba34aSJunchao Zhang   B       = product->B;
1329076ba34aSJunchao Zhang 
13300e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
13310e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
13320e3ece09SJunchao Zhang   // we still do numeric.
13330e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
13340e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13353ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1336076ba34aSJunchao Zhang   }
1337076ba34aSJunchao Zhang 
1338076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13390e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1340076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13410e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
13420e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13430e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13440e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1345076ba34aSJunchao Zhang   }
13460e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13470e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1349076ba34aSJunchao Zhang }
1350076ba34aSJunchao Zhang 
135166976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1352d71ae5a4SJacob Faibussowitsch {
1353076ba34aSJunchao Zhang   Mat                          A, B;
13540e3ece09SJunchao Zhang   Mat_Product                 *product;
1355076ba34aSJunchao Zhang   MatProductType               ptype;
13560e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1357076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13580e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13590e3ece09SJunchao Zhang   Mat                          Cd, Co;
13600e3ece09SJunchao Zhang   MPI_Comm                     comm;
1361076ba34aSJunchao Zhang 
1362076ba34aSJunchao Zhang   PetscFunctionBegin;
13630e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1364076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13650e3ece09SJunchao Zhang   product = C->product;
13660e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1367076ba34aSJunchao Zhang   ptype = product->type;
1368076ba34aSJunchao Zhang   A     = product->A;
1369076ba34aSJunchao Zhang   B     = product->B;
1370076ba34aSJunchao Zhang 
1371076ba34aSJunchao Zhang   switch (ptype) {
13729371c9d4SSatish Balay   case MATPRODUCT_AB:
13739371c9d4SSatish Balay     m = A->rmap->n;
13749371c9d4SSatish Balay     n = B->cmap->n;
13759371c9d4SSatish Balay     M = A->rmap->N;
13769371c9d4SSatish Balay     N = B->cmap->N;
13779371c9d4SSatish Balay     break;
13789371c9d4SSatish Balay   case MATPRODUCT_AtB:
13799371c9d4SSatish Balay     m = A->cmap->n;
13809371c9d4SSatish Balay     n = B->cmap->n;
13819371c9d4SSatish Balay     M = A->cmap->N;
13829371c9d4SSatish Balay     N = B->cmap->N;
13839371c9d4SSatish Balay     break;
13849371c9d4SSatish Balay   case MATPRODUCT_PtAP:
13859371c9d4SSatish Balay     m = B->cmap->n;
13869371c9d4SSatish Balay     n = B->cmap->n;
13879371c9d4SSatish Balay     M = B->cmap->N;
13889371c9d4SSatish Balay     N = B->cmap->N;
13899371c9d4SSatish Balay     break; /* BtAB */
1390d71ae5a4SJacob Faibussowitsch   default:
13910e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1392076ba34aSJunchao Zhang   }
1393076ba34aSJunchao Zhang 
13949566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
13959566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
13969566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
13979566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1398076ba34aSJunchao Zhang 
13990e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
14000e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1401076ba34aSJunchao Zhang 
1402076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
14030e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14040e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
14050e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1406076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
14070e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14080e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
14090e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
14100e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
14110e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
14120e3ece09SJunchao Zhang 
14130e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14140e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
14150e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
14160e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
14170e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
14180e3ece09SJunchao Zhang 
14190e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
14200e3ece09SJunchao Zhang     n = B->cmap->n;
14210e3ece09SJunchao Zhang     M = A->rmap->N;
14220e3ece09SJunchao Zhang     N = B->cmap->N;
14230e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
14240e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
14250e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
14260e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
14270e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
14280e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
14290e3ece09SJunchao Zhang 
14300e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14310e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
14320e3ece09SJunchao Zhang 
14330e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
14340e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1435076ba34aSJunchao Zhang   }
14360e3ece09SJunchao Zhang 
14370e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
14380e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
14390e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1440*39cfb508SMark Adams   /* set block sizes */
1441*39cfb508SMark Adams   switch (ptype) {
1442*39cfb508SMark Adams   case MATPRODUCT_PtAP:
1443*39cfb508SMark Adams     if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs));
1444*39cfb508SMark Adams     break;
1445*39cfb508SMark Adams   case MATPRODUCT_RARt:
1446*39cfb508SMark Adams     if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs));
1447*39cfb508SMark Adams     break;
1448*39cfb508SMark Adams   case MATPRODUCT_ABC:
1449*39cfb508SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, product->C));
1450*39cfb508SMark Adams     break;
1451*39cfb508SMark Adams   case MATPRODUCT_AB:
1452*39cfb508SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
1453*39cfb508SMark Adams     break;
1454*39cfb508SMark Adams   case MATPRODUCT_AtB:
1455*39cfb508SMark Adams     if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs));
1456*39cfb508SMark Adams     break;
1457*39cfb508SMark Adams   case MATPRODUCT_ABt:
1458*39cfb508SMark Adams     if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs));
1459*39cfb508SMark Adams     break;
1460*39cfb508SMark Adams   default:
1461*39cfb508SMark Adams     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]);
1462*39cfb508SMark Adams   }
14630e3ece09SJunchao Zhang   C->product->data       = pdata;
1464076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1465076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1467076ba34aSJunchao Zhang }
1468076ba34aSJunchao Zhang 
1469d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1470d71ae5a4SJacob Faibussowitsch {
1471076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1472076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1473076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1474076ba34aSJunchao Zhang 
1475076ba34aSJunchao Zhang   PetscFunctionBegin;
1476076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
147748a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1478076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1479076ba34aSJunchao Zhang     switch (product->type) {
1480076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1481076ba34aSJunchao Zhang       if (product->api_user) {
1482d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14839566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1484d0609cedSBarry Smith         PetscOptionsEnd();
1485076ba34aSJunchao Zhang       } else {
1486d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14879566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1488d0609cedSBarry Smith         PetscOptionsEnd();
1489076ba34aSJunchao Zhang       }
1490076ba34aSJunchao Zhang       break;
1491076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1492076ba34aSJunchao Zhang       if (product->api_user) {
1493d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
14949566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1495d0609cedSBarry Smith         PetscOptionsEnd();
1496076ba34aSJunchao Zhang       } else {
1497d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
14989566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1499d0609cedSBarry Smith         PetscOptionsEnd();
1500076ba34aSJunchao Zhang       }
1501076ba34aSJunchao Zhang       break;
1502076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1503076ba34aSJunchao Zhang       if (product->api_user) {
1504d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
15059566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1506d0609cedSBarry Smith         PetscOptionsEnd();
1507076ba34aSJunchao Zhang       } else {
1508d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
15099566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1510d0609cedSBarry Smith         PetscOptionsEnd();
1511076ba34aSJunchao Zhang       }
1512076ba34aSJunchao Zhang       break;
1513d71ae5a4SJacob Faibussowitsch     default:
1514d71ae5a4SJacob Faibussowitsch       break;
1515076ba34aSJunchao Zhang     }
1516076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1517076ba34aSJunchao Zhang   }
1518076ba34aSJunchao Zhang   if (match) {
1519076ba34aSJunchao Zhang     switch (product->type) {
1520076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1521076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1522d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1523d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1524d71ae5a4SJacob Faibussowitsch       break;
1525d71ae5a4SJacob Faibussowitsch     default:
1526d71ae5a4SJacob Faibussowitsch       break;
1527076ba34aSJunchao Zhang     }
1528076ba34aSJunchao Zhang   }
1529076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
153048a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
15313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1532076ba34aSJunchao Zhang }
1533076ba34aSJunchao Zhang 
15342c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device
15352c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos {
15362c4ab24aSJunchao Zhang   PetscCount           n;
15372c4ab24aSJunchao Zhang   PetscSF              sf;
15382c4ab24aSJunchao Zhang   PetscCount           Annz, Bnnz;
15392c4ab24aSJunchao Zhang   PetscCount           Annz2, Bnnz2;
15402c4ab24aSJunchao Zhang   PetscCountKokkosView Ajmap1, Aperm1;
15412c4ab24aSJunchao Zhang   PetscCountKokkosView Bjmap1, Bperm1;
15422c4ab24aSJunchao Zhang   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
15432c4ab24aSJunchao Zhang   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
15442c4ab24aSJunchao Zhang   PetscCountKokkosView Cperm1;
15452c4ab24aSJunchao Zhang   MatScalarKokkosView  sendbuf, recvbuf;
15462c4ab24aSJunchao Zhang 
154792896123SJunchao Zhang   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
15482c4ab24aSJunchao Zhang   {
154992896123SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
155092896123SJunchao Zhang 
155192896123SJunchao Zhang     n       = coo_h->n;
155292896123SJunchao Zhang     sf      = coo_h->sf;
155392896123SJunchao Zhang     Annz    = coo_h->Annz;
155492896123SJunchao Zhang     Bnnz    = coo_h->Bnnz;
155592896123SJunchao Zhang     Annz2   = coo_h->Annz2;
155692896123SJunchao Zhang     Bnnz2   = coo_h->Bnnz2;
155792896123SJunchao Zhang     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
155892896123SJunchao Zhang     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
155992896123SJunchao Zhang     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
156092896123SJunchao Zhang     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
156192896123SJunchao Zhang     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
156292896123SJunchao Zhang     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
156392896123SJunchao Zhang     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
156492896123SJunchao Zhang     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
156592896123SJunchao Zhang     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
156692896123SJunchao Zhang     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
156792896123SJunchao Zhang     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
156892896123SJunchao Zhang     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
156992896123SJunchao Zhang     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
15702c4ab24aSJunchao Zhang     PetscCallVoid(PetscObjectReference((PetscObject)sf));
15712c4ab24aSJunchao Zhang   }
15722c4ab24aSJunchao Zhang 
15732c4ab24aSJunchao Zhang   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
15742c4ab24aSJunchao Zhang };
15752c4ab24aSJunchao Zhang 
157649abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void **data)
15772c4ab24aSJunchao Zhang {
15782c4ab24aSJunchao Zhang   PetscFunctionBegin;
157949abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(*data));
15802c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15812c4ab24aSJunchao Zhang }
15822c4ab24aSJunchao Zhang 
1583d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1584d71ae5a4SJacob Faibussowitsch {
15852c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
15862c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJ       *coo_h;
15872c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo_d;
158842550becSJunchao Zhang 
158942550becSJunchao Zhang   PetscFunctionBegin;
159030203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1591cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15929566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15939566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15949566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
15952c4ab24aSJunchao Zhang 
15962c4ab24aSJunchao Zhang   // Copy the COO struct to device
15972c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
15982c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
15992c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
16002c4ab24aSJunchao Zhang 
16012c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
16022c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
16032c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
160449abdd8aSBarry Smith   PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
16052c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
16062c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
16073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
160842550becSJunchao Zhang }
160942550becSJunchao Zhang 
1610d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1611d71ae5a4SJacob Faibussowitsch {
1612394ed5ebSJunchao Zhang   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
161342550becSJunchao Zhang   Mat                            A = mpiaij->A, B = mpiaij->B;
161442550becSJunchao Zhang   MatScalarKokkosView            Aa, Ba;
1615394ed5ebSJunchao Zhang   MatScalarKokkosView            v1;
161642550becSJunchao Zhang   PetscMemType                   memtype;
16172c4ab24aSJunchao Zhang   PetscContainer                 container;
16182c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos     *coo;
161992896123SJunchao Zhang   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
162042550becSJunchao Zhang 
162142550becSJunchao Zhang   PetscFunctionBegin;
16222c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
16232c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
16242c4ab24aSJunchao Zhang 
16252c4ab24aSJunchao Zhang   const auto &n      = coo->n;
16262c4ab24aSJunchao Zhang   const auto &Annz   = coo->Annz;
16272c4ab24aSJunchao Zhang   const auto &Annz2  = coo->Annz2;
16282c4ab24aSJunchao Zhang   const auto &Bnnz   = coo->Bnnz;
16292c4ab24aSJunchao Zhang   const auto &Bnnz2  = coo->Bnnz2;
16302c4ab24aSJunchao Zhang   const auto &vsend  = coo->sendbuf;
16312c4ab24aSJunchao Zhang   const auto &v2     = coo->recvbuf;
16322c4ab24aSJunchao Zhang   const auto &Ajmap1 = coo->Ajmap1;
16332c4ab24aSJunchao Zhang   const auto &Ajmap2 = coo->Ajmap2;
16342c4ab24aSJunchao Zhang   const auto &Aimap2 = coo->Aimap2;
16352c4ab24aSJunchao Zhang   const auto &Bjmap1 = coo->Bjmap1;
16362c4ab24aSJunchao Zhang   const auto &Bjmap2 = coo->Bjmap2;
16372c4ab24aSJunchao Zhang   const auto &Bimap2 = coo->Bimap2;
16382c4ab24aSJunchao Zhang   const auto &Aperm1 = coo->Aperm1;
16392c4ab24aSJunchao Zhang   const auto &Aperm2 = coo->Aperm2;
16402c4ab24aSJunchao Zhang   const auto &Bperm1 = coo->Bperm1;
16412c4ab24aSJunchao Zhang   const auto &Bperm2 = coo->Bperm2;
16422c4ab24aSJunchao Zhang   const auto &Cperm1 = coo->Cperm1;
16432c4ab24aSJunchao Zhang 
16449566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
164542550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
164692896123SJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
164742550becSJunchao Zhang   } else {
16482c4ab24aSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
164942550becSJunchao Zhang   }
165042550becSJunchao Zhang 
165142550becSJunchao Zhang   if (imode == INSERT_VALUES) {
16529566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
16539566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1654394ed5ebSJunchao Zhang   } else {
16559566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
16569566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
165742550becSJunchao Zhang   }
165842550becSJunchao Zhang 
165908bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
166042550becSJunchao Zhang   /* Pack entries to be sent to remote */
166192896123SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
166242550becSJunchao Zhang 
166342550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
16642c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1665158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
16669371c9d4SSatish Balay   Kokkos::parallel_for(
166792896123SJunchao Zhang     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1668158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1669158ec288SJunchao Zhang       if (i < Annz) {
1670158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1671ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1672158ec288SJunchao Zhang       } else {
1673158ec288SJunchao Zhang         i -= Annz;
1674158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1675ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1676158ec288SJunchao Zhang       }
1677158ec288SJunchao Zhang     });
16782c4ab24aSJunchao Zhang   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
167942550becSJunchao Zhang 
1680158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16819371c9d4SSatish Balay   Kokkos::parallel_for(
168292896123SJunchao Zhang     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1683158ec288SJunchao Zhang       if (i < Annz2) {
1684158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1685158ec288SJunchao Zhang       } else {
1686158ec288SJunchao Zhang         i -= Annz2;
1687158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1688158ec288SJunchao Zhang       }
1689158ec288SJunchao Zhang     });
169008bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
169142550becSJunchao Zhang 
1692394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
16939566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
16949566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1695394ed5ebSJunchao Zhang   } else {
16969566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
16979566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1698394ed5ebSJunchao Zhang   }
16993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170042550becSJunchao Zhang }
170142550becSJunchao Zhang 
17022c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1703d71ae5a4SJacob Faibussowitsch {
1704076ba34aSJunchao Zhang   PetscFunctionBegin;
17059566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
17069566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
17079566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
17089566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
170957761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
171057761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
171157761e9aSJunchao Zhang #endif
17129566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
17133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1714076ba34aSJunchao Zhang }
1715076ba34aSJunchao Zhang 
1716f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1717f4747e26SJunchao Zhang {
1718f4747e26SJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1719f4747e26SJunchao Zhang   PetscBool   congruent;
1720f4747e26SJunchao Zhang 
1721f4747e26SJunchao Zhang   PetscFunctionBegin;
1722f4747e26SJunchao Zhang   PetscCall(MatHasCongruentLayouts(A, &congruent));
1723f4747e26SJunchao Zhang   if (congruent) { // square matrix and the diagonals are solely in the diag block
1724f4747e26SJunchao Zhang     PetscCall(MatShift(mpiaij->A, a));
1725f4747e26SJunchao Zhang   } else { // too hard, use the general version
1726f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1727f4747e26SJunchao Zhang   }
1728f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1729f4747e26SJunchao Zhang }
1730f4747e26SJunchao Zhang 
17312c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
17322c4ab24aSJunchao Zhang {
17332c4ab24aSJunchao Zhang   PetscFunctionBegin;
17342c4ab24aSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
17352c4ab24aSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
17362c4ab24aSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
17372c4ab24aSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
17382c4ab24aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
17392c4ab24aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1740f4747e26SJunchao Zhang   B->ops->shift                 = MatShift_MPIAIJKokkos;
17412c4ab24aSJunchao Zhang 
17422c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
17432c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
17442c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
17452c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
174657761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
174757761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
174857761e9aSJunchao Zhang #endif
17492c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
17502c4ab24aSJunchao Zhang }
17512c4ab24aSJunchao Zhang 
1752d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1753d71ae5a4SJacob Faibussowitsch {
17548c3ff71bSJunchao Zhang   Mat         B;
1755076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
17568c3ff71bSJunchao Zhang 
17578c3ff71bSJunchao Zhang   PetscFunctionBegin;
17588c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17599566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
17608c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17619566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
17628c3ff71bSJunchao Zhang   }
17638c3ff71bSJunchao Zhang   B = *newmat;
17648c3ff71bSJunchao Zhang 
17656f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17669566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
17679566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
17689566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
17698c3ff71bSJunchao Zhang 
1770076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
17719566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
17729566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
17739566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
17742c4ab24aSJunchao Zhang   PetscCall(MatSetOps_MPIAIJKokkos(B));
17753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17768c3ff71bSJunchao Zhang }
17772c4ab24aSJunchao Zhang 
/*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos

   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types

   Options Database Key:
.  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`

  Level: beginner

.seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
M*/
1790d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1791d71ae5a4SJacob Faibussowitsch {
17928c3ff71bSJunchao Zhang   PetscFunctionBegin;
17939566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
17949566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
17959566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
17963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17978c3ff71bSJunchao Zhang }
17988c3ff71bSJunchao Zhang 
17998c3ff71bSJunchao Zhang /*@C
1800f8d70eaaSPierre Jolivet   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
18018c3ff71bSJunchao Zhang   (the default parallel PETSc format).  This matrix will ultimately pushed down
180220f4b53cSBarry Smith   to Kokkos for calculations.
18038c3ff71bSJunchao Zhang 
18048c3ff71bSJunchao Zhang   Collective
18058c3ff71bSJunchao Zhang 
18068c3ff71bSJunchao Zhang   Input Parameters:
180711a5261eSBarry Smith + comm  - MPI communicator, set to `PETSC_COMM_SELF`
180820f4b53cSBarry Smith . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
180920f4b53cSBarry Smith            This value should be the same as the local size used in creating the
181020f4b53cSBarry Smith            y vector for the matrix-vector product y = Ax.
181120f4b53cSBarry Smith . n     - This value should be the same as the local size used in creating the
181220f4b53cSBarry Smith        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
181320f4b53cSBarry Smith        calculated if N is given) For square matrices n is almost always `m`.
181420f4b53cSBarry Smith . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
181520f4b53cSBarry Smith . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
181620f4b53cSBarry Smith . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
181720f4b53cSBarry Smith            (same value is used for all local rows)
181820f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the
181920f4b53cSBarry Smith            DIAGONAL portion of the local submatrix (possibly different for each row)
182020f4b53cSBarry Smith            or `NULL`, if `d_nz` is used to specify the nonzero structure.
182120f4b53cSBarry Smith            The size of this array is equal to the number of local rows, i.e `m`.
182220f4b53cSBarry Smith            For matrices you plan to factor you must leave room for the diagonal entry and
182320f4b53cSBarry Smith            put in the entry even if it is zero.
182420f4b53cSBarry Smith . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
182520f4b53cSBarry Smith            submatrix (same value is used for all local rows).
182620f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the
182720f4b53cSBarry Smith            OFF-DIAGONAL portion of the local submatrix (possibly different for
182820f4b53cSBarry Smith            each row) or `NULL`, if `o_nz` is used to specify the nonzero
182920f4b53cSBarry Smith            structure. The size of this array is equal to the number
183020f4b53cSBarry Smith            of local rows, i.e `m`.
18318c3ff71bSJunchao Zhang 
18328c3ff71bSJunchao Zhang   Output Parameter:
18338c3ff71bSJunchao Zhang . A - the matrix
18348c3ff71bSJunchao Zhang 
18352ef1f0ffSBarry Smith   Level: intermediate
18362ef1f0ffSBarry Smith 
18372ef1f0ffSBarry Smith   Notes:
183811a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
18398c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
184011a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
18418c3ff71bSJunchao Zhang 
1842667f096bSBarry Smith   The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
18438c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
18442ef1f0ffSBarry Smith   either one (as in Fortran) or zero.
18458c3ff71bSJunchao Zhang 
1846f8d70eaaSPierre Jolivet .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1847f8d70eaaSPierre Jolivet           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
18488c3ff71bSJunchao Zhang @*/
1849d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1850d71ae5a4SJacob Faibussowitsch {
18518c3ff71bSJunchao Zhang   PetscMPIInt size;
18528c3ff71bSJunchao Zhang 
18538c3ff71bSJunchao Zhang   PetscFunctionBegin;
18549566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
18559566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
18569566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
18578c3ff71bSJunchao Zhang   if (size > 1) {
18589566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
18599566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
18608c3ff71bSJunchao Zhang   } else {
18619566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
18629566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
18638c3ff71bSJunchao Zhang   }
18643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18658c3ff71bSJunchao Zhang }
1866