xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision f3d3cd90648576fafae91a24e2611daaef7bcd2e)
1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
542550becSJunchao Zhang #include <petsc/private/sfimpl.h>
692896123SJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
72c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
88c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
9076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
100e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
1111d22bbfSJunchao Zhang 
1266976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
13d71ae5a4SJacob Faibussowitsch {
1430203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
158c3ff71bSJunchao Zhang 
168c3ff71bSJunchao Zhang   PetscFunctionBegin;
179566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1830203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1930203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
2030203840SJunchao Zhang    */
2130203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2292896123SJunchao Zhang     PetscScalarKokkosView v;
2392896123SJunchao Zhang 
2430203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2530203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2692896123SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
2792896123SJunchao Zhang     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
2892896123SJunchao Zhang     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
2930203840SJunchao Zhang   }
303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
318c3ff71bSJunchao Zhang }
328c3ff71bSJunchao Zhang 
3366976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
34d71ae5a4SJacob Faibussowitsch {
352cdb1aeaSJunchao Zhang   Mat_MPIAIJ *mpiaij;
368c3ff71bSJunchao Zhang 
378c3ff71bSJunchao Zhang   PetscFunctionBegin;
382cdb1aeaSJunchao Zhang   // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things
392cdb1aeaSJunchao Zhang   PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz));
402cdb1aeaSJunchao Zhang   mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
412cdb1aeaSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A));
422cdb1aeaSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B));
433ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
448c3ff71bSJunchao Zhang }
458c3ff71bSJunchao Zhang 
4666976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
47d71ae5a4SJacob Faibussowitsch {
488c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
498c3ff71bSJunchao Zhang   PetscInt    nt;
508c3ff71bSJunchao Zhang 
518c3ff71bSJunchao Zhang   PetscFunctionBegin;
529566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
5308401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
549566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
559566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
569566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
579566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
583ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
598c3ff71bSJunchao Zhang }
608c3ff71bSJunchao Zhang 
6166976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
62d71ae5a4SJacob Faibussowitsch {
638c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
648c3ff71bSJunchao Zhang   PetscInt    nt;
658c3ff71bSJunchao Zhang 
668c3ff71bSJunchao Zhang   PetscFunctionBegin;
679566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
6808401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
699566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
709566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
719566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
729566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
748c3ff71bSJunchao Zhang }
758c3ff71bSJunchao Zhang 
7666976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
77d71ae5a4SJacob Faibussowitsch {
788c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
798c3ff71bSJunchao Zhang   PetscInt    nt;
808c3ff71bSJunchao Zhang 
818c3ff71bSJunchao Zhang   PetscFunctionBegin;
829566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8308401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
849566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
859566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
869566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
879566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
898c3ff71bSJunchao Zhang }
908c3ff71bSJunchao Zhang 
91076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
92076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
93076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
94076ba34aSJunchao Zhang */
9566976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
96d71ae5a4SJacob Faibussowitsch {
97076ba34aSJunchao Zhang   Mat             Ad, Ao;
98076ba34aSJunchao Zhang   const PetscInt *cmap;
99076ba34aSJunchao Zhang 
100076ba34aSJunchao Zhang   PetscFunctionBegin;
1019566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1029566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
103076ba34aSJunchao Zhang   if (glob) {
104076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1059566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1069566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1079566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1089566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
109076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
110076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1119566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
112076ba34aSJunchao Zhang   }
1133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
114076ba34aSJunchao Zhang }
115076ba34aSJunchao Zhang 
1160e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
117076ba34aSJunchao Zhang struct MatMatStruct {
1180e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1190e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1200e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1210e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1220e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1230e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1240e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1250e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1260e3ece09SJunchao Zhang 
1270e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1280e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1290e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1300e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1310e3ece09SJunchao Zhang 
132aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1330e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1340e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1350e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1360e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1370e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
138076ba34aSJunchao Zhang 
139d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
140d71ae5a4SJacob Faibussowitsch   {
1413ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1423ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1433ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
144076ba34aSJunchao Zhang   }
145076ba34aSJunchao Zhang };
146076ba34aSJunchao Zhang 
147076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1480e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1490e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1500e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
151076ba34aSJunchao Zhang };
152076ba34aSJunchao Zhang 
153076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1540e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1550e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1560e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1570e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
158076ba34aSJunchao Zhang };
159076ba34aSJunchao Zhang 
1609371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1613ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1623ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1633ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
1640e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
165076ba34aSJunchao Zhang 
166d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
167d71ae5a4SJacob Faibussowitsch   {
168076ba34aSJunchao Zhang     delete mmAB;
169076ba34aSJunchao Zhang     delete mmAtB;
1700e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
171076ba34aSJunchao Zhang   }
172076ba34aSJunchao Zhang };
173076ba34aSJunchao Zhang 
174d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
175d71ae5a4SJacob Faibussowitsch {
176076ba34aSJunchao Zhang   PetscFunctionBegin;
1779566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
1783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
179076ba34aSJunchao Zhang }
180076ba34aSJunchao Zhang 
1810e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
1820e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
1830e3ece09SJunchao Zhang template <class ExecutionSpace>
1840e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
185d71ae5a4SJacob Faibussowitsch {
1861aa660a0SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1)
1871aa660a0SJunchao Zhang   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>();
1881aa660a0SJunchao Zhang #else
1891aa660a0SJunchao Zhang   constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>;
1901aa660a0SJunchao Zhang #endif
1910e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
192076ba34aSJunchao Zhang 
193076ba34aSJunchao Zhang   PetscFunctionBegin;
1940e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
195076ba34aSJunchao Zhang 
1960e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
197076ba34aSJunchao Zhang 
1980e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
199076ba34aSJunchao Zhang 
2000e3ece09SJunchao Zhang   if (vector_length < 1) {
2010e3ece09SJunchao Zhang     vector_length = 1;
2020e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
203076ba34aSJunchao Zhang   }
204076ba34aSJunchao Zhang 
2050e3ece09SJunchao Zhang   // Determine rows per thread
2060e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2071aa660a0SJunchao Zhang     if (is_gpu_exec_space) rows_per_thread = 1;
2080e3ece09SJunchao Zhang     else {
2090e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2100e3ece09SJunchao Zhang         rows_per_thread = 256;
2110e3ece09SJunchao Zhang       } else rows_per_thread = 64;
212076ba34aSJunchao Zhang     }
213076ba34aSJunchao Zhang   }
214076ba34aSJunchao Zhang 
2150e3ece09SJunchao Zhang   if (team_size < 1) {
2161aa660a0SJunchao Zhang     if (is_gpu_exec_space) {
2170e3ece09SJunchao Zhang       team_size = 256 / vector_length;
218076ba34aSJunchao Zhang     } else {
2190e3ece09SJunchao Zhang       team_size = 1;
2200e3ece09SJunchao Zhang     }
221076ba34aSJunchao Zhang   }
222076ba34aSJunchao Zhang 
2230e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
224076ba34aSJunchao Zhang 
2250e3ece09SJunchao Zhang   if (rows_per_team < 0) {
2260e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
2270e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
2280e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
2290e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
2300e3ece09SJunchao Zhang   }
2313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
232076ba34aSJunchao Zhang }
233076ba34aSJunchao Zhang 
2340e3ece09SJunchao Zhang /*
2350e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
236076ba34aSJunchao Zhang 
237076ba34aSJunchao Zhang   Input Parameters:
2380e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
2390e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
2400e3ece09SJunchao Zhang .  m           - size of indices[], the second set
2410e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
242076ba34aSJunchao Zhang 
243076ba34aSJunchao Zhang   Output Parameters:
2440e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
2450e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
2460e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
2470e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
248076ba34aSJunchao Zhang 
2490e3ece09SJunchao Zhang    Example, say
2500e3ece09SJunchao Zhang     n1         = 5
2510e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
2520e3ece09SJunchao Zhang     m          = 4
2530e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
25411a5261eSBarry Smith 
2550e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
2560e3ece09SJunchao Zhang     n2         = 7
2570e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
2580e3ece09SJunchao Zhang 
2590e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
2600e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
2610e3ece09SJunchao Zhang 
2620e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
2630e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
264076ba34aSJunchao Zhang */
2650e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
266d71ae5a4SJacob Faibussowitsch {
2670e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
2680e3ece09SJunchao Zhang   PetscHashIter iter;
2690e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
2700e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
271076ba34aSJunchao Zhang 
272076ba34aSJunchao Zhang   PetscFunctionBegin;
2730e3ece09SJunchao Zhang   tot = 0;
2740e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
2750e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
2760e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
2770e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
278076ba34aSJunchao Zhang   }
279076ba34aSJunchao Zhang 
2800e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
2810e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
2820e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
283076ba34aSJunchao Zhang   }
284076ba34aSJunchao Zhang 
2850e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
2860e3ece09SJunchao Zhang   n2 = tot;
2870e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
2880e3ece09SJunchao Zhang   tot = 0;
2890e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
2900e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
2910e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
2920e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
2930e3ece09SJunchao Zhang     garray2[tot++] = key;
294076ba34aSJunchao Zhang   }
295076ba34aSJunchao Zhang 
2960e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
2970e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
2980e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
2990e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
300f0e6e2d1SJunchao Zhang 
3010e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
302f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3030e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3040e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3050e3ece09SJunchao Zhang     indices[i] = val;
3060e3ece09SJunchao Zhang   }
3070e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3080e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3090e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3100e3ece09SJunchao Zhang   *n2_      = n2;
3110e3ece09SJunchao Zhang   *garray2_ = garray2;
3120e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3130e3ece09SJunchao Zhang }
314f0e6e2d1SJunchao Zhang 
3150e3ece09SJunchao Zhang /*
3160e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3170e3ece09SJunchao Zhang 
3180e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
3190e3ece09SJunchao Zhang 
3200e3ece09SJunchao Zhang   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
3210e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3220e3ece09SJunchao Zhang 
3230e3ece09SJunchao Zhang   Input Parameters:
3240e3ece09SJunchao Zhang +  comm       - MPI communicator of E
3250e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
3260e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
3270e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
3280e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
3290e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
3300e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
3310e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
3320e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
3330e3ece09SJunchao Zhang 
3340e3ece09SJunchao Zhang   Output Parameters:
3350e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
3360e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
3370e3ece09SJunchao Zhang 
3380e3ece09SJunchao Zhang   Notes:
3390e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
3400e3ece09SJunchao Zhang 
3410e3ece09SJunchao Zhang  */
3420e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
3430e3ece09SJunchao Zhang {
3440e3ece09SJunchao Zhang   PetscFunctionBegin;
3450e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
3460e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
3470e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
3480e3ece09SJunchao Zhang 
3490e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
3500e3ece09SJunchao Zhang 
3510e3ece09SJunchao Zhang     // Do the analysis on host
35245402d8aSJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map);
35345402d8aSJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries);
35445402d8aSJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map);
35545402d8aSJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries);
3560e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
3570e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
3580e3ece09SJunchao Zhang 
3590e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
3607b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
3610e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
3620e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
3630e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
3640e3ece09SJunchao Zhang       PetscInt        count, step;
3650e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
3660e3ece09SJunchao Zhang       first = Bj + Bi[i];
3670e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
368f0e6e2d1SJunchao Zhang       count = last - first;
369f0e6e2d1SJunchao Zhang       while (count > 0) {
370f0e6e2d1SJunchao Zhang         it   = first;
371f0e6e2d1SJunchao Zhang         step = count / 2;
372f0e6e2d1SJunchao Zhang         it += step;
3730e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
374f0e6e2d1SJunchao Zhang           first = ++it;
375f0e6e2d1SJunchao Zhang           count -= step + 1;
376f0e6e2d1SJunchao Zhang         } else count = step;
377f0e6e2d1SJunchao Zhang       }
3780e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
3790e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
380f0e6e2d1SJunchao Zhang     }
381f0e6e2d1SJunchao Zhang 
3820e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
3830e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
3840e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
385c09cee04SJames Wright     PetscMPIInt        niranks, nranks;
3860e3ece09SJunchao Zhang     MPI_Request       *reqs;
3870e3ece09SJunchao Zhang     PetscMPIInt        tag;
3880e3ece09SJunchao Zhang     PetscSF            reduceSF;
3890e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
390f0e6e2d1SJunchao Zhang 
3910e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
3920e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
3930e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
394f0e6e2d1SJunchao Zhang 
3950e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
3960e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
3970e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
3980e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
399f0e6e2d1SJunchao Zhang 
4000e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4010e3ece09SJunchao Zhang 
4020e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4030e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4040e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4056497c311SBarry Smith     for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
4066497c311SBarry Smith     for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]);
4070e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4080e3ece09SJunchao Zhang 
4090e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4100e3ece09SJunchao Zhang     rdisp[0] = 0;
4110e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4120e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4130e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4140e3ece09SJunchao Zhang     }
4150e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4160e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4170e3ece09SJunchao Zhang 
4186497c311SBarry Smith     for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]);
4196497c311SBarry Smith     for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]);
4200e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4210e3ece09SJunchao Zhang 
4220e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
4230e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
4240e3ece09SJunchao Zhang     PetscSFNode *iremote;
4250e3ece09SJunchao Zhang 
4260e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
4270e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
4280e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
4290e3ece09SJunchao Zhang 
4300e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
4310e3ece09SJunchao Zhang       PetscInt count = 0;
4320e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
4330e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
4340e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
4350e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
4360e3ece09SJunchao Zhang       }
4370e3ece09SJunchao Zhang       nleaves += count;
4380e3ece09SJunchao Zhang     }
4390e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
4400e3ece09SJunchao Zhang 
4410e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
4420e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
4430e3ece09SJunchao Zhang 
4440e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
4450e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
4460e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
4470e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
4480e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
4490e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
4500e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
4510e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
4520e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
4530e3ece09SJunchao Zhang         if (j < nzLeft) {
4540e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
4550e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
4560e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
4570e3ece09SJunchao Zhang         } else {
4580e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
4590e3ece09SJunchao Zhang         }
4600e3ece09SJunchao Zhang       }
4610e3ece09SJunchao Zhang     }
4620e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
4630e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
4640e3ece09SJunchao Zhang 
4650e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
4660e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
4670e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
4680e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
4690e3ece09SJunchao Zhang 
4700e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
4710e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
4720e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
4730e3ece09SJunchao Zhang 
4740e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
4757b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
4760e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
4770e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
4780e3ece09SJunchao Zhang     PetscInt                iter;
4790e3ece09SJunchao Zhang 
4800e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
4810e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
4820e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
4830e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
4840e3ece09SJunchao Zhang     iter = 0;
4850e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
4860e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
4870e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
4880e3ece09SJunchao Zhang 
4890e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
4900e3ece09SJunchao Zhang 
4910e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
4920e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
4930e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
4940e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
4950e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
4960e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
4970e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
4980e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
4990e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5000e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5010e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5020e3ece09SJunchao Zhang         jbuf2 += len;
5030e3ece09SJunchao Zhang         pbuf2 += len;
5040e3ece09SJunchao Zhang         nz += len;
5050e3ece09SJunchao Zhang       }
5060e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5070e3ece09SJunchao Zhang 
5080e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5090e3ece09SJunchao Zhang       PetscInt cur = 0;
5100e3ece09SJunchao Zhang       while (cur < nz) {
5110e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5120e3ece09SJunchao Zhang         PetscInt dups      = 1;
5130e3ece09SJunchao Zhang 
5140e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5150e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5160e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5170e3ece09SJunchao Zhang           FdnzDups += dups;
5180e3ece09SJunchao Zhang         } else {
5190e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5200e3ece09SJunchao Zhang           FonzDups += dups;
5210e3ece09SJunchao Zhang         }
5220e3ece09SJunchao Zhang         cur += dups;
5230e3ece09SJunchao Zhang       }
5240e3ece09SJunchao Zhang 
5250e3ece09SJunchao Zhang       FnzDups += nz;
5260e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5270e3ece09SJunchao Zhang     }
5280e3ece09SJunchao Zhang 
5290e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
5300e3ece09SJunchao Zhang     Foi = Foi_h.data();
5310e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
5320e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
5330e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
5340e3ece09SJunchao Zhang     }
5350e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
5360e3ece09SJunchao Zhang     Fonz = Foi[Fm];
5370e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
5380e3ece09SJunchao Zhang 
5390e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
5407b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
5417b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
5427b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
5430e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
5440e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
5450e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
5460e3ece09SJunchao Zhang 
5470e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
5480e3ece09SJunchao Zhang     Fdjmap[0] = 0;
5490e3ece09SJunchao Zhang     Fojmap[0] = 0;
5500e3ece09SJunchao Zhang     FnzDups   = 0;
5510e3ece09SJunchao Zhang     Fdnz      = 0;
5520e3ece09SJunchao Zhang     Fonz      = 0;
5530e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
5540e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
5550e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
5560e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
5570e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
5580e3ece09SJunchao Zhang 
5590e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5600e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5610e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
5620e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
5630e3ece09SJunchao Zhang       }
5640e3ece09SJunchao Zhang 
5650e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
5660e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
5670e3ece09SJunchao Zhang       PetscInt cur = 0;
5680e3ece09SJunchao Zhang       while (cur < nz) {
5690e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5700e3ece09SJunchao Zhang         PetscInt dups      = 1;
5710e3ece09SJunchao Zhang 
5720e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5730e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5740e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
5750e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
5760e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
5770e3ece09SJunchao Zhang           FdnzDups += dups;
5780e3ece09SJunchao Zhang           Fdnz++;
5790e3ece09SJunchao Zhang         } else {
5800e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
5810e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
5820e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
5830e3ece09SJunchao Zhang           FonzDups += dups;
5840e3ece09SJunchao Zhang           Fonz++;
5850e3ece09SJunchao Zhang         }
5860e3ece09SJunchao Zhang         cur += dups;
5870e3ece09SJunchao Zhang         FnzDups += dups;
5880e3ece09SJunchao Zhang       }
5890e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5900e3ece09SJunchao Zhang     }
5910e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
5920e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
5930e3ece09SJunchao Zhang 
5940e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
5950e3ece09SJunchao Zhang     PetscInt n2, *garray2;
5960e3ece09SJunchao Zhang 
5970e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
5980e3ece09SJunchao Zhang     mm->sf       = reduceSF;
5997b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
6007b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
601aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6020e3ece09SJunchao Zhang     mm->n        = n2;
6030e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6040e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6050e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6060e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6070e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6080e3ece09SJunchao Zhang 
6090e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
6107b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6110e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6120e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
6137b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6140e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6150e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6160e3ece09SJunchao Zhang 
6170e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6180e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6190e3ece09SJunchao Zhang 
6200e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6210e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6220e3ece09SJunchao Zhang 
6230e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
6240e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
6250e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
6260e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
6270e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
6280e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
6290e3ece09SJunchao Zhang 
6300e3ece09SJunchao Zhang   // Handy aliases
6310e3ece09SJunchao Zhang   auto       &Aa           = A.values;
6320e3ece09SJunchao Zhang   auto       &Ba           = B.values;
6330e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
6340e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
6350e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
6360e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
6370e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
6380e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
6390e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
6400e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
6410e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
6420e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
6430e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
6440e3ece09SJunchao Zhang 
6450e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
6460e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
647d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
6480e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
6490e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
6500e3ece09SJunchao Zhang         if (i < Em) {
6510e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
6520e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
6530e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
6540e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
6550e3ece09SJunchao Zhang 
6560e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
6570e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
6580e3ece09SJunchao Zhang             if (j < nzleft) { // B left
6590e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
6600e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
6610e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
6620e3ece09SJunchao Zhang             } else { // B right
6630e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
664f0e6e2d1SJunchao Zhang             }
665f0e6e2d1SJunchao Zhang           });
666f0e6e2d1SJunchao Zhang         }
667f0e6e2d1SJunchao Zhang       });
6680e3ece09SJunchao Zhang     }));
6690e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
670f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
671f0e6e2d1SJunchao Zhang }
6720e3ece09SJunchao Zhang 
673aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
6740e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
6750e3ece09SJunchao Zhang {
6760e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
6770e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
6780e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
6790e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
6800e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
6810e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
6820e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
6830e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
6840e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
6850e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
6860e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
6870e3ece09SJunchao Zhang 
688d326c3f1SJunchao Zhang   PetscFunctionBegin;
6890e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
6900e3ece09SJunchao Zhang 
6910e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
6920e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
693d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
6940e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
6950e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
6960e3ece09SJunchao Zhang       Fda(i) = sum;
6970e3ece09SJunchao Zhang     }));
6980e3ece09SJunchao Zhang 
6990e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
700d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
7010e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7020e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7030e3ece09SJunchao Zhang       Foa(i) = sum;
7040e3ece09SJunchao Zhang     }));
7050e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7060e3ece09SJunchao Zhang }
7070e3ece09SJunchao Zhang 
7080e3ece09SJunchao Zhang /*
7090e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7100e3ece09SJunchao Zhang 
7110e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
7120e3ece09SJunchao Zhang   device and involves various index mapping.
7130e3ece09SJunchao Zhang 
7140e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
7150e3ece09SJunchao Zhang   Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
7160e3ece09SJunchao Zhang   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7170e3ece09SJunchao Zhang   F has the same column layout as E.
7180e3ece09SJunchao Zhang 
7190e3ece09SJunchao Zhang   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
720aaa8cc7dSPierre Jolivet   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
7210e3ece09SJunchao Zhang   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
7220e3ece09SJunchao Zhang   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
7230e3ece09SJunchao Zhang   column indices in Fo and update Fo with local indices.
7240e3ece09SJunchao Zhang 
7250e3ece09SJunchao Zhang    Input Parameters:
7260e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
7279c89aa79SPierre Jolivet .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
7280e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
7290e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
7300e3ece09SJunchao Zhang 
7310e3ece09SJunchao Zhang     Output Parameters:
7320e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
7330e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
7340e3ece09SJunchao Zhang 
7350e3ece09SJunchao Zhang     Notes:
7360e3ece09SJunchao Zhang     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
7370e3ece09SJunchao Zhang     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
7380e3ece09SJunchao Zhang */
7390e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
7400e3ece09SJunchao Zhang {
7410e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
7420e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
7430e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
7440e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
7450e3ece09SJunchao Zhang   MPI_Comm          comm;
7460e3ece09SJunchao Zhang 
7470e3ece09SJunchao Zhang   PetscFunctionBegin;
7480e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
7490e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
7500e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
7510e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
7520e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
7530e3ece09SJunchao Zhang     PetscInt        cstart, cend;
7540e3ece09SJunchao Zhang     PetscSF         bcastSF;
7550e3ece09SJunchao Zhang 
7560e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
7570e3ece09SJunchao Zhang 
7580e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
7597b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
7600e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
7610e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
7620e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
7630e3ece09SJunchao Zhang       PetscInt        count, step;
7640e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
7650e3ece09SJunchao Zhang       first = Bj + Bi[i];
7660e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
7670e3ece09SJunchao Zhang       count = last - first;
7680e3ece09SJunchao Zhang       while (count > 0) {
7690e3ece09SJunchao Zhang         it   = first;
7700e3ece09SJunchao Zhang         step = count / 2;
7710e3ece09SJunchao Zhang         it += step;
7720e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
7730e3ece09SJunchao Zhang           first = ++it;
7740e3ece09SJunchao Zhang           count -= step + 1;
7750e3ece09SJunchao Zhang         } else count = step;
7760e3ece09SJunchao Zhang       }
7770e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
7780e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
7790e3ece09SJunchao Zhang     }
7800e3ece09SJunchao Zhang 
7810e3ece09SJunchao Zhang     // Compute row pointer Fi of F
7820e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
7830e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
7840e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
7850e3ece09SJunchao Zhang     Fi[0] = 0;
7860e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
7870e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
7880e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
7890e3ece09SJunchao Zhang     Fnz = Fi[Fm];
7900e3ece09SJunchao Zhang 
7910e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
7920e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
7930e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
794c09cee04SJames Wright     PetscMPIInt        niranks, nranks;
795c09cee04SJames Wright     PetscInt          *sdisp, *rdisp;
7960e3ece09SJunchao Zhang     MPI_Request       *reqs;
7970e3ece09SJunchao Zhang     PetscMPIInt        tag;
7980e3ece09SJunchao Zhang 
7990e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8000e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8010e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8020e3ece09SJunchao Zhang 
8030e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8040e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8050e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8060e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8070e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8080e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8090e3ece09SJunchao Zhang       }
8100e3ece09SJunchao Zhang     }
8110e3ece09SJunchao Zhang 
8120e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8136497c311SBarry Smith     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8146497c311SBarry Smith     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8150e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8160e3ece09SJunchao Zhang 
8170e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8180e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8190e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8200e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8210e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8220e3ece09SJunchao Zhang       PetscInt k = 0;
8230e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
8240e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
8250e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
8260e3ece09SJunchao Zhang         k++;
8270e3ece09SJunchao Zhang       }
8280e3ece09SJunchao Zhang     }
8290e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
8300e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
8310e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
8320e3ece09SJunchao Zhang 
8330e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
8347b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
8350e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
8360e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
8377b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
8380e3ece09SJunchao Zhang 
8390e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
8400e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
8410e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
8420e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
8430e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
8440e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
8450e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
8460e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
8470e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
8480e3ece09SJunchao Zhang         if (j < nzLeft) {
8490e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
8500e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
8510e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
8520e3ece09SJunchao Zhang         } else {
8530e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
8540e3ece09SJunchao Zhang         }
8550e3ece09SJunchao Zhang       }
8560e3ece09SJunchao Zhang     }
8570e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
8580e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
8590e3ece09SJunchao Zhang 
8600e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
8617b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
8627b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
8630e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
8640e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
8650e3ece09SJunchao Zhang 
8660e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
8670e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
8680e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
8690e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
8700e3ece09SJunchao Zhang       first       = Fj + Fi[i];
8710e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
8720e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
8730e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
8740e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
8750e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
8760e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
8770e3ece09SJunchao Zhang     }
8780e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
8790e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
8800e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
8810e3ece09SJunchao Zhang     }
8820e3ece09SJunchao Zhang 
8830e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
8840e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
8857b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
8860e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
8870e3ece09SJunchao Zhang 
8880e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
8890e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
8900e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
8910e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
8920e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
8930e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
8940e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
8950e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
8960e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
8970e3ece09SJunchao Zhang         } else { // right, in global
8980e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
8990e3ece09SJunchao Zhang         }
9000e3ece09SJunchao Zhang       }
9010e3ece09SJunchao Zhang     }
9020e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9030e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9040e3ece09SJunchao Zhang 
9050e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9060e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9070e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9080e3ece09SJunchao Zhang 
9090e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9100e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
9117b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9120e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9130e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9140e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9150e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9160e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9170e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
9187b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
9197b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9200e3ece09SJunchao Zhang     mm->garray    = garray2;
9210e3ece09SJunchao Zhang     mm->n         = n2;
9220e3ece09SJunchao Zhang 
9230e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
9247b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
9250e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
9260e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
9270e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
9280e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
9290e3ece09SJunchao Zhang 
9300e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
9310e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
9320e3ece09SJunchao Zhang 
9330e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
9340e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
9350e3ece09SJunchao Zhang 
9360e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
9370e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
9380e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
9390e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
9400e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
9410e3ece09SJunchao Zhang 
9420e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
9430e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
9440e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
9450e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
9460e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
9470e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
9480e3ece09SJunchao Zhang 
9490e3ece09SJunchao Zhang   // Sync E's value to device
950*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace()));
951*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(bkok->a_dual, PetscGetKokkosExecutionSpace()));
9520e3ece09SJunchao Zhang 
9530e3ece09SJunchao Zhang   // Handy aliases
9540e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
9550e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
9560e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
9570e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
9580e3ece09SJunchao Zhang 
9590e3ece09SJunchao Zhang   // Fetch the plans
9600e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
9610e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
9620e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
9630e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
9640e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
9650e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
9660e3ece09SJunchao Zhang 
9670e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
9680e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
9690e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
9700e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
9710e3ece09SJunchao Zhang 
9720e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
9730e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
974d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
9750e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
9760e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
9770e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
9780e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
9790e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
9800e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
9810e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
9820e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
9830e3ece09SJunchao Zhang 
9840e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
9850e3ece09SJunchao Zhang             if (j < nzleft) { // B left
9860e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
9870e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
9880e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
9890e3ece09SJunchao Zhang             } else { // B right
9900e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
9910e3ece09SJunchao Zhang             }
9920e3ece09SJunchao Zhang           });
9930e3ece09SJunchao Zhang         }
9940e3ece09SJunchao Zhang       });
9950e3ece09SJunchao Zhang     }));
9960e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
9970e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
9980e3ece09SJunchao Zhang }
9990e3ece09SJunchao Zhang 
10000e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10010e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10020e3ece09SJunchao Zhang {
10030e3ece09SJunchao Zhang   PetscFunctionBegin;
10040e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10050e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10060e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10070e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10080e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10090e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10100e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10110e3ece09SJunchao Zhang 
10120e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10130e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10140e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10150e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10160e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10170e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10180e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10190e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10200e3ece09SJunchao Zhang 
10210e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10220e3ece09SJunchao Zhang 
10230e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
10240e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1025d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10260e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10270e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
10280e3ece09SJunchao Zhang         if (i < Fm) {
10290e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
10300e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
10310e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
10320e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
10330e3ece09SJunchao Zhang 
10340e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10350e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
10360e3ece09SJunchao Zhang             if (j < nzLeft) { // left
10370e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
10380e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
10390e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
10400e3ece09SJunchao Zhang             } else { // right
10410e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
10420e3ece09SJunchao Zhang             }
10430e3ece09SJunchao Zhang           });
10440e3ece09SJunchao Zhang         }
10450e3ece09SJunchao Zhang       });
10460e3ece09SJunchao Zhang     }));
10470e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10480e3ece09SJunchao Zhang }
10490e3ece09SJunchao Zhang 
10500e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
10510e3ece09SJunchao Zhang {
10520e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
10530e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
10540e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
10550e3ece09SJunchao Zhang   PetscInt        cstart, cend;
10560e3ece09SJunchao Zhang   MPI_Comm        comm;
10570e3ece09SJunchao Zhang 
10580e3ece09SJunchao Zhang   PetscFunctionBegin;
10590e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
10600e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
10610e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
10620e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
10630e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
10640e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
10650e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
10660e3ece09SJunchao Zhang 
10670e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
10680e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
10690e3ece09SJunchao Zhang 
10700e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
10710e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
10720e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
10730e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1074f0e6e2d1SJunchao Zhang   #endif
10750e3ece09SJunchao Zhang #endif
10760e3ece09SJunchao Zhang 
10770e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
10780e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
10790e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
10800e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
10810e3ece09SJunchao Zhang 
10820e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
10830e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
10840e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
10850e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
10860e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
10870e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
10880e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
10890e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1090d326c3f1SJunchao Zhang 
10910e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
10920e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
10930e3ece09SJunchao Zhang #endif
10940e3ece09SJunchao Zhang 
10950e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
10967b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
10970e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
10980e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
10990e3ece09SJunchao Zhang 
11000e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11010e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11020e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11030e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11040e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11050e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11060e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11070e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11080e3ece09SJunchao Zhang #endif
11090e3ece09SJunchao Zhang 
11100e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11110e3ece09SJunchao Zhang 
11120e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
11137b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11140e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1115d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11160e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11170e3ece09SJunchao Zhang 
11180e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11190e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11200e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11210e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11220e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11230e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11240e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11250e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11260e3ece09SJunchao Zhang }
11270e3ece09SJunchao Zhang 
11280e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11290e3ece09SJunchao Zhang {
11300e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11310e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11320e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
11330e3ece09SJunchao Zhang   MPI_Comm        comm;
11340e3ece09SJunchao Zhang 
11350e3ece09SJunchao Zhang   PetscFunctionBegin;
11360e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11370e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11380e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11390e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11400e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11410e3ece09SJunchao Zhang 
11420e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11430e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11440e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11450e3ece09SJunchao Zhang 
11460e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
11470e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
11480e3ece09SJunchao Zhang 
11490e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11500e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11510e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11520e3ece09SJunchao Zhang 
11530e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
11540e3ece09SJunchao Zhang 
11550e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11560e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11570e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11580e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11590e3ece09SJunchao Zhang }
1160f0e6e2d1SJunchao Zhang 
1161076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1162076ba34aSJunchao Zhang 
1163076ba34aSJunchao Zhang   Input Parameters:
1164076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1165076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1166076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1167076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1168076ba34aSJunchao Zhang */
1169d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1170d71ae5a4SJacob Faibussowitsch {
11710e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11720e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11730e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1174076ba34aSJunchao Zhang 
1175076ba34aSJunchao Zhang   PetscFunctionBegin;
11760e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
11770e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
11780e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11790e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11800e3ece09SJunchao Zhang 
11810e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11820e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11830e3ece09SJunchao Zhang 
11840e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11850e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11860e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11870e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
11880e3ece09SJunchao Zhang   #endif
1189f0e6e2d1SJunchao Zhang #endif
1190f0e6e2d1SJunchao Zhang 
11910e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
11920e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
11930e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
11940e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1195076ba34aSJunchao Zhang 
11960e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
11977b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11980e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1199076ba34aSJunchao Zhang 
12000e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12010e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12020e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12030e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12040e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12050e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12060e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12070e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12080e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12090e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12100e3ece09SJunchao Zhang #endif
1211076ba34aSJunchao Zhang 
12120e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1213076ba34aSJunchao Zhang 
12140e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12150e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12160e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12170e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12180e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12190e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12200e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12210e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12220e3ece09SJunchao Zhang #endif
1223076ba34aSJunchao Zhang 
12240e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
12257b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
12260e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1227d326c3f1SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
12280e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
12290e3ece09SJunchao Zhang 
12300e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12310e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
12320e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
12330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
12340e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
12350e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
12360e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
12373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1238076ba34aSJunchao Zhang }
1239076ba34aSJunchao Zhang 
12400e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1241d71ae5a4SJacob Faibussowitsch {
12420e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12430e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12440e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1245076ba34aSJunchao Zhang 
1246076ba34aSJunchao Zhang   PetscFunctionBegin;
12470e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12480e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12490e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12500e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1251076ba34aSJunchao Zhang 
12520e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
12530e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1254076ba34aSJunchao Zhang 
12550e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12560e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12570e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1258076ba34aSJunchao Zhang 
12590e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1260076ba34aSJunchao Zhang 
12610e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12620e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12630e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12640e3ece09SJunchao Zhang 
12650e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12660e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
12670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
12683ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1269076ba34aSJunchao Zhang }
1270076ba34aSJunchao Zhang 
127166976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1272d71ae5a4SJacob Faibussowitsch {
12730e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
12740e3ece09SJunchao Zhang   Mat_Product                 *product;
12750e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1276076ba34aSJunchao Zhang   MatProductType               ptype;
12770e3ece09SJunchao Zhang   Mat                          A, B;
1278076ba34aSJunchao Zhang 
1279076ba34aSJunchao Zhang   PetscFunctionBegin;
12800e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
12810e3ece09SJunchao Zhang   product = C->product;
12820e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1283076ba34aSJunchao Zhang   ptype   = product->type;
1284076ba34aSJunchao Zhang   A       = product->A;
1285076ba34aSJunchao Zhang   B       = product->B;
1286076ba34aSJunchao Zhang 
12870e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
12880e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
12890e3ece09SJunchao Zhang   // we still do numeric.
12900e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
12910e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
12923ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1293076ba34aSJunchao Zhang   }
1294076ba34aSJunchao Zhang 
1295076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
12960e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1297076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
12980e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
12990e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13000e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13010e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1302076ba34aSJunchao Zhang   }
13030e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13040e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1306076ba34aSJunchao Zhang }
1307076ba34aSJunchao Zhang 
130866976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1309d71ae5a4SJacob Faibussowitsch {
1310076ba34aSJunchao Zhang   Mat                          A, B;
13110e3ece09SJunchao Zhang   Mat_Product                 *product;
1312076ba34aSJunchao Zhang   MatProductType               ptype;
13130e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1314076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13150e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13160e3ece09SJunchao Zhang   Mat                          Cd, Co;
13170e3ece09SJunchao Zhang   MPI_Comm                     comm;
1318076ba34aSJunchao Zhang 
1319076ba34aSJunchao Zhang   PetscFunctionBegin;
13200e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1321076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13220e3ece09SJunchao Zhang   product = C->product;
13230e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1324076ba34aSJunchao Zhang   ptype = product->type;
1325076ba34aSJunchao Zhang   A     = product->A;
1326076ba34aSJunchao Zhang   B     = product->B;
1327076ba34aSJunchao Zhang 
1328076ba34aSJunchao Zhang   switch (ptype) {
13299371c9d4SSatish Balay   case MATPRODUCT_AB:
13309371c9d4SSatish Balay     m = A->rmap->n;
13319371c9d4SSatish Balay     n = B->cmap->n;
13329371c9d4SSatish Balay     M = A->rmap->N;
13339371c9d4SSatish Balay     N = B->cmap->N;
13349371c9d4SSatish Balay     break;
13359371c9d4SSatish Balay   case MATPRODUCT_AtB:
13369371c9d4SSatish Balay     m = A->cmap->n;
13379371c9d4SSatish Balay     n = B->cmap->n;
13389371c9d4SSatish Balay     M = A->cmap->N;
13399371c9d4SSatish Balay     N = B->cmap->N;
13409371c9d4SSatish Balay     break;
13419371c9d4SSatish Balay   case MATPRODUCT_PtAP:
13429371c9d4SSatish Balay     m = B->cmap->n;
13439371c9d4SSatish Balay     n = B->cmap->n;
13449371c9d4SSatish Balay     M = B->cmap->N;
13459371c9d4SSatish Balay     N = B->cmap->N;
13469371c9d4SSatish Balay     break; /* BtAB */
1347d71ae5a4SJacob Faibussowitsch   default:
13480e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1349076ba34aSJunchao Zhang   }
1350076ba34aSJunchao Zhang 
13519566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
13529566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
13539566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
13549566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1355076ba34aSJunchao Zhang 
13560e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
13570e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1358076ba34aSJunchao Zhang 
1359076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13600e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
13610e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
13620e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1363076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13640e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
13650e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
13660e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
13670e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
13680e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
13690e3ece09SJunchao Zhang 
13700e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
13710e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
13720e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
13730e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
13740e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
13750e3ece09SJunchao Zhang 
13760e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
13770e3ece09SJunchao Zhang     n = B->cmap->n;
13780e3ece09SJunchao Zhang     M = A->rmap->N;
13790e3ece09SJunchao Zhang     N = B->cmap->N;
1380c0c276a7Ssdargavi     PetscCall(MatCreateMPIAIJWithSeqAIJ(comm, M, N, Zd, Zo, mmAB->garray, &Z));
13810e3ece09SJunchao Zhang 
13820e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
13830e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
13840e3ece09SJunchao Zhang 
13850e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
13860e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1387076ba34aSJunchao Zhang   }
13880e3ece09SJunchao Zhang 
13890e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
13900e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1391c0c276a7Ssdargavi   PetscCall(MatSetMPIAIJWithSplitSeqAIJ(C, Cd, Co, mm->garray));
139239cfb508SMark Adams   /* set block sizes */
139339cfb508SMark Adams   switch (ptype) {
139439cfb508SMark Adams   case MATPRODUCT_PtAP:
139539cfb508SMark Adams     if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs));
139639cfb508SMark Adams     break;
139739cfb508SMark Adams   case MATPRODUCT_RARt:
139839cfb508SMark Adams     if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs));
139939cfb508SMark Adams     break;
140039cfb508SMark Adams   case MATPRODUCT_ABC:
140139cfb508SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, product->C));
140239cfb508SMark Adams     break;
140339cfb508SMark Adams   case MATPRODUCT_AB:
140439cfb508SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
140539cfb508SMark Adams     break;
140639cfb508SMark Adams   case MATPRODUCT_AtB:
140739cfb508SMark Adams     if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs));
140839cfb508SMark Adams     break;
140939cfb508SMark Adams   case MATPRODUCT_ABt:
141039cfb508SMark Adams     if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs));
141139cfb508SMark Adams     break;
141239cfb508SMark Adams   default:
141339cfb508SMark Adams     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]);
141439cfb508SMark Adams   }
14150e3ece09SJunchao Zhang   C->product->data       = pdata;
1416076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1417076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14183ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1419076ba34aSJunchao Zhang }
1420076ba34aSJunchao Zhang 
1421d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1422d71ae5a4SJacob Faibussowitsch {
1423076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1424076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1425076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1426076ba34aSJunchao Zhang 
1427076ba34aSJunchao Zhang   PetscFunctionBegin;
1428076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
142948a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1430076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1431076ba34aSJunchao Zhang     switch (product->type) {
1432076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1433076ba34aSJunchao Zhang       if (product->api_user) {
1434d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14359566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1436d0609cedSBarry Smith         PetscOptionsEnd();
1437076ba34aSJunchao Zhang       } else {
1438d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14399566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1440d0609cedSBarry Smith         PetscOptionsEnd();
1441076ba34aSJunchao Zhang       }
1442076ba34aSJunchao Zhang       break;
1443076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1444076ba34aSJunchao Zhang       if (product->api_user) {
1445d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
14469566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1447d0609cedSBarry Smith         PetscOptionsEnd();
1448076ba34aSJunchao Zhang       } else {
1449d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
14509566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1451d0609cedSBarry Smith         PetscOptionsEnd();
1452076ba34aSJunchao Zhang       }
1453076ba34aSJunchao Zhang       break;
1454076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1455076ba34aSJunchao Zhang       if (product->api_user) {
1456d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
14579566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1458d0609cedSBarry Smith         PetscOptionsEnd();
1459076ba34aSJunchao Zhang       } else {
1460d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
14619566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1462d0609cedSBarry Smith         PetscOptionsEnd();
1463076ba34aSJunchao Zhang       }
1464076ba34aSJunchao Zhang       break;
1465d71ae5a4SJacob Faibussowitsch     default:
1466d71ae5a4SJacob Faibussowitsch       break;
1467076ba34aSJunchao Zhang     }
1468076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1469076ba34aSJunchao Zhang   }
1470076ba34aSJunchao Zhang   if (match) {
1471076ba34aSJunchao Zhang     switch (product->type) {
1472076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1473076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1474d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1475d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1476d71ae5a4SJacob Faibussowitsch       break;
1477d71ae5a4SJacob Faibussowitsch     default:
1478d71ae5a4SJacob Faibussowitsch       break;
1479076ba34aSJunchao Zhang     }
1480076ba34aSJunchao Zhang   }
1481076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
148248a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
14833ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1484076ba34aSJunchao Zhang }
1485076ba34aSJunchao Zhang 
14862c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device
14872c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos {
14882c4ab24aSJunchao Zhang   PetscCount           n;
14892c4ab24aSJunchao Zhang   PetscSF              sf;
14902c4ab24aSJunchao Zhang   PetscCount           Annz, Bnnz;
14912c4ab24aSJunchao Zhang   PetscCount           Annz2, Bnnz2;
14922c4ab24aSJunchao Zhang   PetscCountKokkosView Ajmap1, Aperm1;
14932c4ab24aSJunchao Zhang   PetscCountKokkosView Bjmap1, Bperm1;
14942c4ab24aSJunchao Zhang   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
14952c4ab24aSJunchao Zhang   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
14962c4ab24aSJunchao Zhang   PetscCountKokkosView Cperm1;
14972c4ab24aSJunchao Zhang   MatScalarKokkosView  sendbuf, recvbuf;
14982c4ab24aSJunchao Zhang 
149992896123SJunchao Zhang   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
15002c4ab24aSJunchao Zhang   {
15014df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
150292896123SJunchao Zhang 
150392896123SJunchao Zhang     n       = coo_h->n;
150492896123SJunchao Zhang     sf      = coo_h->sf;
150592896123SJunchao Zhang     Annz    = coo_h->Annz;
150692896123SJunchao Zhang     Bnnz    = coo_h->Bnnz;
150792896123SJunchao Zhang     Annz2   = coo_h->Annz2;
150892896123SJunchao Zhang     Bnnz2   = coo_h->Bnnz2;
150992896123SJunchao Zhang     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
151092896123SJunchao Zhang     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
151192896123SJunchao Zhang     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
151292896123SJunchao Zhang     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
151392896123SJunchao Zhang     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
151492896123SJunchao Zhang     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
151592896123SJunchao Zhang     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
151692896123SJunchao Zhang     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
151792896123SJunchao Zhang     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
151892896123SJunchao Zhang     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
151992896123SJunchao Zhang     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
152092896123SJunchao Zhang     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
152192896123SJunchao Zhang     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
15222c4ab24aSJunchao Zhang     PetscCallVoid(PetscObjectReference((PetscObject)sf));
15232c4ab24aSJunchao Zhang   }
15242c4ab24aSJunchao Zhang 
15252c4ab24aSJunchao Zhang   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
15262c4ab24aSJunchao Zhang };
15272c4ab24aSJunchao Zhang 
152849abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void **data)
15292c4ab24aSJunchao Zhang {
15302c4ab24aSJunchao Zhang   PetscFunctionBegin;
153149abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(*data));
15322c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15332c4ab24aSJunchao Zhang }
15342c4ab24aSJunchao Zhang 
1535d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1536d71ae5a4SJacob Faibussowitsch {
15372c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
15382c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJ       *coo_h;
15392c4ab24aSJunchao Zhang   MatCOOStruct_MPIAIJKokkos *coo_d;
154042550becSJunchao Zhang 
154142550becSJunchao Zhang   PetscFunctionBegin;
154230203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1543cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15449566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15459566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15469566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
15472c4ab24aSJunchao Zhang 
15482c4ab24aSJunchao Zhang   // Copy the COO struct to device
15492c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
15502c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
15512c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
15522c4ab24aSJunchao Zhang 
15532c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
15542c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
15552c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
155649abdd8aSBarry Smith   PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
15572c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
15582c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
15593ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
156042550becSJunchao Zhang }
156142550becSJunchao Zhang 
/*
  MatSetValuesCOO_MPIAIJKokkos - Inserts or adds the COO values v[] into mat on the device

  Uses the device COO struct attached by MatSetPreallocationCOO_MPIAIJKokkos().

  v[] may be in host or device memory; host values are copied to the device first. Entries destined for other
  ranks are packed and sent through the COO PetscSF. While that communication is in flight, locally provided
  entries are summed into the diagonal (A) and off-diagonal (B) blocks. Entries received from other ranks are
  then added by a second kernel. With INSERT_VALUES the previous matrix values are discarded; otherwise they are
  accumulated into.
*/
static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
{
  Mat_MPIAIJ                   *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
  Mat                           A = mpiaij->A, B = mpiaij->B;
  MatScalarKokkosView           Aa, Ba;
  MatScalarKokkosView           v1;
  PetscMemType                  memtype;
  PetscContainer                container;
  MatCOOStruct_MPIAIJKokkos    *coo;
  Kokkos::DefaultExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
  PetscCall(PetscContainerGetPointer(container, (void **)&coo));

  // Named local aliases of the COO struct's fields. The KOKKOS_LAMBDA kernels below capture by copy, so they get
  // their own copies of these views/counts instead of dereferencing the host pointer coo inside a kernel.
  const auto &n      = coo->n;
  const auto &Annz   = coo->Annz;
  const auto &Annz2  = coo->Annz2;
  const auto &Bnnz   = coo->Bnnz;
  const auto &Bnnz2  = coo->Bnnz2;
  const auto &vsend  = coo->sendbuf;
  const auto &v2     = coo->recvbuf;
  const auto &Ajmap1 = coo->Ajmap1;
  const auto &Ajmap2 = coo->Ajmap2;
  const auto &Aimap2 = coo->Aimap2;
  const auto &Bjmap1 = coo->Bjmap1;
  const auto &Bjmap2 = coo->Bjmap2;
  const auto &Bimap2 = coo->Bimap2;
  const auto &Aperm1 = coo->Aperm1;
  const auto &Aperm2 = coo->Aperm2;
  const auto &Bperm1 = coo->Bperm1;
  const auto &Bperm2 = coo->Bperm2;
  const auto &Cperm1 = coo->Cperm1;

  PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
  if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
    v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
  } else {
    v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
  }

  // INSERT_VALUES overwrites every value, so write-only access suffices; otherwise existing values are read too
  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
    PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
    PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
  }

  PetscCall(PetscLogGpuTimeBegin());
  /* Pack entries to be sent to remote */
  Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

  /* Send remote entries to their owner and overlap the communication with local computation */
  PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
  /* Add local entries to A and B in one kernel */
  // Indices [0, Annz) address Aa and [Annz, Annz + Bnnz) address Ba. Entry i is the sum of v1 over the permuted
  // range [jmap1(i), jmap1(i+1)). The old value is kept only when imode is not INSERT_VALUES.
  Kokkos::parallel_for(
    Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
      PetscScalar sum = 0.0;
      if (i < Annz) {
        for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
        Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
      } else {
        i -= Annz;
        for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
        Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
      }
    });
  PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

  /* Add received remote entries to A and B in one kernel */
  // Received values in v2 are always accumulated (+=) at positions Aimap2(i) / Bimap2(i), on top of the values
  // written by the local kernel above
  Kokkos::parallel_for(
    Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
      if (i < Annz2) {
        for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
      } else {
        i -= Annz2;
        for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
      }
    });
  PetscCall(PetscLogGpuTimeEnd());

  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
    PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
165342550becSJunchao Zhang 
16542c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1655d71ae5a4SJacob Faibussowitsch {
1656076ba34aSJunchao Zhang   PetscFunctionBegin;
16579566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
16589566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
16599566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
16609566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
166157761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
166257761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL));
166357761e9aSJunchao Zhang #endif
16649566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
16653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1666076ba34aSJunchao Zhang }
1667076ba34aSJunchao Zhang 
1668f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1669f4747e26SJunchao Zhang {
1670f4747e26SJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1671f4747e26SJunchao Zhang   PetscBool   congruent;
1672f4747e26SJunchao Zhang 
1673f4747e26SJunchao Zhang   PetscFunctionBegin;
1674f4747e26SJunchao Zhang   PetscCall(MatHasCongruentLayouts(A, &congruent));
1675f4747e26SJunchao Zhang   if (congruent) { // square matrix and the diagonals are solely in the diag block
1676f4747e26SJunchao Zhang     PetscCall(MatShift(mpiaij->A, a));
1677f4747e26SJunchao Zhang   } else { // too hard, use the general version
1678f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1679f4747e26SJunchao Zhang   }
1680f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1681f4747e26SJunchao Zhang }
1682f4747e26SJunchao Zhang 
16832c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
16842c4ab24aSJunchao Zhang {
16852c4ab24aSJunchao Zhang   PetscFunctionBegin;
16862c4ab24aSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
16872c4ab24aSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
16882c4ab24aSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
16892c4ab24aSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
16902c4ab24aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
16912c4ab24aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1692f4747e26SJunchao Zhang   B->ops->shift                 = MatShift_MPIAIJKokkos;
16932c4ab24aSJunchao Zhang 
16942c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
16952c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
16962c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
16972c4ab24aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
169857761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
169957761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
170057761e9aSJunchao Zhang #endif
17012c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
17022c4ab24aSJunchao Zhang }
17032c4ab24aSJunchao Zhang 
1704d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1705d71ae5a4SJacob Faibussowitsch {
17068c3ff71bSJunchao Zhang   Mat         B;
1707076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
17088c3ff71bSJunchao Zhang 
17098c3ff71bSJunchao Zhang   PetscFunctionBegin;
17108c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
17119566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
17128c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
17139566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
17148c3ff71bSJunchao Zhang   }
17158c3ff71bSJunchao Zhang   B = *newmat;
17168c3ff71bSJunchao Zhang 
17176f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
17189566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
17199566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
17209566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
17218c3ff71bSJunchao Zhang 
1722076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
17239566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
17249566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
17259566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
17262c4ab24aSJunchao Zhang   PetscCall(MatSetOps_MPIAIJKokkos(B));
17273ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17288c3ff71bSJunchao Zhang }
17292c4ab24aSJunchao Zhang 
17303f3ba80aSJunchao Zhang /*MC
173111a5261eSBarry Smith    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
17328c3ff71bSJunchao Zhang 
173315229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
17343f3ba80aSJunchao Zhang 
17352ef1f0ffSBarry Smith    Options Database Key:
17362ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
17373f3ba80aSJunchao Zhang 
17383f3ba80aSJunchao Zhang   Level: beginner
17393f3ba80aSJunchao Zhang 
17401cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
17413f3ba80aSJunchao Zhang M*/
PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
{
  PetscFunctionBegin;
  // NOTE(review): presumably makes sure Kokkos is initialized before any device data is created; confirm in petsc_kokkos.hpp
  PetscCall(PetscKokkosInitializeCheck());
  // Build A as a plain MATMPIAIJ first ...
  PetscCall(MatCreate_MPIAIJ(A));
  // ... then convert it in place (input and output are the same matrix) to MATMPIAIJKOKKOS
  PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
17508c3ff71bSJunchao Zhang 
17518c3ff71bSJunchao Zhang /*@C
  MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
  (the default parallel PETSc format).  This matrix will ultimately be pushed down
  to Kokkos for calculations.
17558c3ff71bSJunchao Zhang 
17568c3ff71bSJunchao Zhang   Collective
17578c3ff71bSJunchao Zhang 
17588c3ff71bSJunchao Zhang   Input Parameters:
+ comm  - MPI communicator
. m     - number of local rows (or `PETSC_DECIDE` to have it calculated if `M` is given)
           This value should be the same as the local size used in creating the
           y vector for the matrix-vector product y = Ax.
. n     - number of local columns (or `PETSC_DECIDE` to have it calculated if `N` is given)
           This value should be the same as the local size used in creating the
           x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
. M     - number of global rows (or `PETSC_DETERMINE` to have it calculated if `m` is given)
. N     - number of global columns (or `PETSC_DETERMINE` to have it calculated if `n` is given)
. d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
           (same value is used for all local rows)
. d_nnz - array containing the number of nonzeros in the various rows of the
           DIAGONAL portion of the local submatrix (possibly different for each row)
           or `NULL`, if `d_nz` is used to specify the nonzero structure.
           The size of this array is equal to the number of local rows, i.e. `m`.
           For matrices you plan to factor you must leave room for the diagonal entry and
           put in the entry even if it is zero.
. o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
           submatrix (same value is used for all local rows).
- o_nnz - array containing the number of nonzeros in the various rows of the
           OFF-DIAGONAL portion of the local submatrix (possibly different for
           each row) or `NULL`, if `o_nz` is used to specify the nonzero
           structure. The size of this array is equal to the number
           of local rows, i.e. `m`.
17838c3ff71bSJunchao Zhang 
17848c3ff71bSJunchao Zhang   Output Parameter:
17858c3ff71bSJunchao Zhang . A - the matrix
17868c3ff71bSJunchao Zhang 
17872ef1f0ffSBarry Smith   Level: intermediate
17882ef1f0ffSBarry Smith 
17892ef1f0ffSBarry Smith   Notes:
179011a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
17918c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradigm instead of this routine directly.
179211a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
17938c3ff71bSJunchao Zhang 
  The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
  storage.  That is, the stored row and column indices can begin at
  either one (as in Fortran) or zero.

  If `comm` contains a single MPI process, a `MATSEQAIJKOKKOS` matrix is created and `o_nz`/`o_nnz`
  are ignored; otherwise a `MATMPIAIJKOKKOS` matrix is created.
17978c3ff71bSJunchao Zhang 
1798f8d70eaaSPierre Jolivet .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1799f8d70eaaSPierre Jolivet           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
18008c3ff71bSJunchao Zhang @*/
1801d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1802d71ae5a4SJacob Faibussowitsch {
18038c3ff71bSJunchao Zhang   PetscMPIInt size;
18048c3ff71bSJunchao Zhang 
18058c3ff71bSJunchao Zhang   PetscFunctionBegin;
18069566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
18079566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
18089566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
18098c3ff71bSJunchao Zhang   if (size > 1) {
18109566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
18119566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
18128c3ff71bSJunchao Zhang   } else {
18139566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
18149566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
18158c3ff71bSJunchao Zhang   }
18163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18178c3ff71bSJunchao Zhang }
1818