xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 7b8d4ba6a08626823f4aa3c70fa9bb901abea174)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
3076ba34aSJunchao Zhang #include <petscsf.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
642550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
80e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp>
911d22bbfSJunchao Zhang 
10d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
11d71ae5a4SJacob Faibussowitsch {
1230203840SJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
138c3ff71bSJunchao Zhang 
148c3ff71bSJunchao Zhang   PetscFunctionBegin;
159566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1630203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1730203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1830203840SJunchao Zhang    */
1930203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2030203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2130203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2230203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2330203840SJunchao Zhang   }
24a587d139SMark 
253ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
268c3ff71bSJunchao Zhang }
278c3ff71bSJunchao Zhang 
28d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
29d71ae5a4SJacob Faibussowitsch {
308c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
318c3ff71bSJunchao Zhang 
328c3ff71bSJunchao Zhang   PetscFunctionBegin;
339566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
349566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
356a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
368c3ff71bSJunchao Zhang   if (d_nnz) {
376a29ce69SStefano Zampini     PetscInt i;
38ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
398c3ff71bSJunchao Zhang   }
408c3ff71bSJunchao Zhang   if (o_nnz) {
416a29ce69SStefano Zampini     PetscInt i;
42ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
438c3ff71bSJunchao Zhang   }
446a29ce69SStefano Zampini #endif
456a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
46eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
476a29ce69SStefano Zampini #else
489566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
496a29ce69SStefano Zampini #endif
509566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
519566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
529566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
536a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
549566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
556a29ce69SStefano Zampini 
566a29ce69SStefano Zampini   if (!mpiaij->A) {
579566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
589566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
596a29ce69SStefano Zampini   }
606a29ce69SStefano Zampini   if (!mpiaij->B) {
616a29ce69SStefano Zampini     PetscMPIInt size;
629566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
639566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
649566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
658c3ff71bSJunchao Zhang   }
669566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
679566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
699566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
708c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
728c3ff71bSJunchao Zhang }
738c3ff71bSJunchao Zhang 
74d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
75d71ae5a4SJacob Faibussowitsch {
768c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
778c3ff71bSJunchao Zhang   PetscInt    nt;
788c3ff71bSJunchao Zhang 
798c3ff71bSJunchao Zhang   PetscFunctionBegin;
809566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8108401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
829566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
839566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
849566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
859566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
878c3ff71bSJunchao Zhang }
888c3ff71bSJunchao Zhang 
89d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
90d71ae5a4SJacob Faibussowitsch {
918c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
928c3ff71bSJunchao Zhang   PetscInt    nt;
938c3ff71bSJunchao Zhang 
948c3ff71bSJunchao Zhang   PetscFunctionBegin;
959566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
9608401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
979566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
989566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
999566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1009566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
1013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1028c3ff71bSJunchao Zhang }
1038c3ff71bSJunchao Zhang 
104d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
105d71ae5a4SJacob Faibussowitsch {
1068c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1078c3ff71bSJunchao Zhang   PetscInt    nt;
1088c3ff71bSJunchao Zhang 
1098c3ff71bSJunchao Zhang   PetscFunctionBegin;
1109566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
11108401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1129566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1139566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1149566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1159566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1178c3ff71bSJunchao Zhang }
1188c3ff71bSJunchao Zhang 
119076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
120076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
121076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
122076ba34aSJunchao Zhang */
123d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
124d71ae5a4SJacob Faibussowitsch {
125076ba34aSJunchao Zhang   Mat             Ad, Ao;
126076ba34aSJunchao Zhang   const PetscInt *cmap;
127076ba34aSJunchao Zhang 
128076ba34aSJunchao Zhang   PetscFunctionBegin;
1299566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
131076ba34aSJunchao Zhang   if (glob) {
132076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1339566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1349566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1359566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1369566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
137076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
138076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1399566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
140076ba34aSJunchao Zhang   }
1413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
142076ba34aSJunchao Zhang }
143076ba34aSJunchao Zhang 
1440e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
145076ba34aSJunchao Zhang struct MatMatStruct {
1460e3ece09SJunchao Zhang   PetscInt            n, *garray;     // C's garray and its size.
1470e3ece09SJunchao Zhang   KokkosCsrMatrix     Cd, Co;         // C is in split form matrices (all in local column indcies)
1480e3ece09SJunchao Zhang   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
1490e3ece09SJunchao Zhang   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
1500e3ece09SJunchao Zhang   PetscIntKokkosView  E_NzLeft;
1510e3ece09SJunchao Zhang   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
1520e3ece09SJunchao Zhang   MatScalarKokkosView rootBuf, leafBuf;
1530e3ece09SJunchao Zhang   KokkosCsrMatrix     Fd, Fo; // F in split form
1540e3ece09SJunchao Zhang 
1550e3ece09SJunchao Zhang   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
1560e3ece09SJunchao Zhang   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
1570e3ece09SJunchao Zhang   KernelHandle kh3; // compute C3
1580e3ece09SJunchao Zhang   KernelHandle kh4; // compute C4
1590e3ece09SJunchao Zhang 
160aaa8cc7dSPierre Jolivet   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
1610e3ece09SJunchao Zhang   PetscInt E_VectorLength;
1620e3ece09SJunchao Zhang   PetscInt E_RowsPerTeam;
1630e3ece09SJunchao Zhang   PetscInt F_TeamSize;
1640e3ece09SJunchao Zhang   PetscInt F_VectorLength;
1650e3ece09SJunchao Zhang   PetscInt F_RowsPerTeam;
166076ba34aSJunchao Zhang 
167d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
168d71ae5a4SJacob Faibussowitsch   {
1693ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1703ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
1713ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
172076ba34aSJunchao Zhang   }
173076ba34aSJunchao Zhang };
174076ba34aSJunchao Zhang 
175076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1760e3ece09SJunchao Zhang   PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
1770e3ece09SJunchao Zhang   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
1780e3ece09SJunchao Zhang   PetscIntKokkosView rowoffset;
179076ba34aSJunchao Zhang };
180076ba34aSJunchao Zhang 
181076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
1820e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
1830e3ece09SJunchao Zhang   MatColIdxKokkosView Fdjperm;
1840e3ece09SJunchao Zhang   MatColIdxKokkosView Fojmap;
1850e3ece09SJunchao Zhang   MatColIdxKokkosView Fojperm;
186076ba34aSJunchao Zhang };
187076ba34aSJunchao Zhang 
1889371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1893ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1903ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1913ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
1920e3ece09SJunchao Zhang   Mat               Z        = nullptr; // store Z=AB in computing BtAB
193076ba34aSJunchao Zhang 
194d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
195d71ae5a4SJacob Faibussowitsch   {
196076ba34aSJunchao Zhang     delete mmAB;
197076ba34aSJunchao Zhang     delete mmAtB;
1980e3ece09SJunchao Zhang     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
199076ba34aSJunchao Zhang   }
200076ba34aSJunchao Zhang };
201076ba34aSJunchao Zhang 
202d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
203d71ae5a4SJacob Faibussowitsch {
204076ba34aSJunchao Zhang   PetscFunctionBegin;
2059566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
2063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
207076ba34aSJunchao Zhang }
208076ba34aSJunchao Zhang 
209076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
210076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
211076ba34aSJunchao Zhang 
212076ba34aSJunchao Zhang   Input Parameters:
213076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
214076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
215076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
216076ba34aSJunchao Zhang 
2172fe279fdSBarry Smith   Output Parameter:
218076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
219076ba34aSJunchao Zhang */
2200e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
221d71ae5a4SJacob Faibussowitsch {
222076ba34aSJunchao Zhang   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
223076ba34aSJunchao Zhang   PetscInt    m, n, M, N, Am, An, Bm, Bn;
224076ba34aSJunchao Zhang 
225076ba34aSJunchao Zhang   PetscFunctionBegin;
2269566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2279566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2289566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2299566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
230076ba34aSJunchao Zhang 
231aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
23208401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
2330e3ece09SJunchao Zhang   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
23408401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
235076ba34aSJunchao Zhang   mpiaij->A      = A;
236076ba34aSJunchao Zhang   mpiaij->B      = B;
2370e3ece09SJunchao Zhang   mpiaij->garray = garray;
238076ba34aSJunchao Zhang 
239076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
240076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
241076ba34aSJunchao Zhang 
2429566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2439566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
244076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
245076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
246076ba34aSJunchao Zhang   */
2479566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2489566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2499566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
2503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
251076ba34aSJunchao Zhang }
252076ba34aSJunchao Zhang 
2530e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
2540e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
2550e3ece09SJunchao Zhang template <class ExecutionSpace>
2560e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
257d71ae5a4SJacob Faibussowitsch {
2580e3ece09SJunchao Zhang   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
259076ba34aSJunchao Zhang 
260076ba34aSJunchao Zhang   PetscFunctionBegin;
2610e3ece09SJunchao Zhang   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices
262076ba34aSJunchao Zhang 
2630e3ece09SJunchao Zhang   if (nnz_per_row < 1) nnz_per_row = 1;
264076ba34aSJunchao Zhang 
2650e3ece09SJunchao Zhang   int max_vector_length = teamPolicy.vector_length_max();
266076ba34aSJunchao Zhang 
2670e3ece09SJunchao Zhang   if (vector_length < 1) {
2680e3ece09SJunchao Zhang     vector_length = 1;
2690e3ece09SJunchao Zhang     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
270076ba34aSJunchao Zhang   }
271076ba34aSJunchao Zhang 
2720e3ece09SJunchao Zhang   // Determine rows per thread
2730e3ece09SJunchao Zhang   if (rows_per_thread < 1) {
2740e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
2750e3ece09SJunchao Zhang     else {
2760e3ece09SJunchao Zhang       if (nnz_per_row < 20 && nnz > 5000000) {
2770e3ece09SJunchao Zhang         rows_per_thread = 256;
2780e3ece09SJunchao Zhang       } else rows_per_thread = 64;
279076ba34aSJunchao Zhang     }
280076ba34aSJunchao Zhang   }
281076ba34aSJunchao Zhang 
2820e3ece09SJunchao Zhang   if (team_size < 1) {
2830e3ece09SJunchao Zhang     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
2840e3ece09SJunchao Zhang       team_size = 256 / vector_length;
285076ba34aSJunchao Zhang     } else {
2860e3ece09SJunchao Zhang       team_size = 1;
2870e3ece09SJunchao Zhang     }
288076ba34aSJunchao Zhang   }
289076ba34aSJunchao Zhang 
2900e3ece09SJunchao Zhang   rows_per_team = rows_per_thread * team_size;
291076ba34aSJunchao Zhang 
2920e3ece09SJunchao Zhang   if (rows_per_team < 0) {
2930e3ece09SJunchao Zhang     PetscInt nnz_per_team = 4096;
2940e3ece09SJunchao Zhang     PetscInt conc         = ExecutionSpace().concurrency();
2950e3ece09SJunchao Zhang     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
2960e3ece09SJunchao Zhang     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
2970e3ece09SJunchao Zhang   }
2983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
299076ba34aSJunchao Zhang }
300076ba34aSJunchao Zhang 
3010e3ece09SJunchao Zhang /*
3020e3ece09SJunchao Zhang   Reduce two sets of global indices into local ones
303076ba34aSJunchao Zhang 
304076ba34aSJunchao Zhang   Input Parameters:
3050e3ece09SJunchao Zhang +  n1          - size of garray1[], the first set
3060e3ece09SJunchao Zhang .  garray1[n1] - a sorted global index array (without duplicates)
3070e3ece09SJunchao Zhang .  m           - size of indices[], the second set
3080e3ece09SJunchao Zhang -  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones
309076ba34aSJunchao Zhang 
310076ba34aSJunchao Zhang   Output Parameters:
3110e3ece09SJunchao Zhang +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
3120e3ece09SJunchao Zhang .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
3130e3ece09SJunchao Zhang .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
3140e3ece09SJunchao Zhang -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]
315076ba34aSJunchao Zhang 
3160e3ece09SJunchao Zhang    Example, say
3170e3ece09SJunchao Zhang     n1         = 5
3180e3ece09SJunchao Zhang     garray1[5] = {1, 4, 7, 8, 10}
3190e3ece09SJunchao Zhang     m          = 4
3200e3ece09SJunchao Zhang     indices[4] = {2, 4, 8, 9}
32111a5261eSBarry Smith 
3220e3ece09SJunchao Zhang    Combining them together, we have 7 global indices in garray2[]
3230e3ece09SJunchao Zhang     n2         = 7
3240e3ece09SJunchao Zhang     garray2[7] = {1, 2, 4, 7, 8, 9, 10}
3250e3ece09SJunchao Zhang 
3260e3ece09SJunchao Zhang    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
3270e3ece09SJunchao Zhang     map[5] = {0, 2, 3, 4, 6}
3280e3ece09SJunchao Zhang 
3290e3ece09SJunchao Zhang    On output, indices[] is updated with local indices
3300e3ece09SJunchao Zhang     indices[4] = {1, 2, 4, 5}
331076ba34aSJunchao Zhang */
3320e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
333d71ae5a4SJacob Faibussowitsch {
3340e3ece09SJunchao Zhang   PetscHMapI    g2l = nullptr;
3350e3ece09SJunchao Zhang   PetscHashIter iter;
3360e3ece09SJunchao Zhang   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
3370e3ece09SJunchao Zhang   PetscInt      n2, *garray2;
338076ba34aSJunchao Zhang 
339076ba34aSJunchao Zhang   PetscFunctionBegin;
3400e3ece09SJunchao Zhang   tot = 0;
3410e3ece09SJunchao Zhang   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
3420e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
3430e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
3440e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
345076ba34aSJunchao Zhang   }
346076ba34aSJunchao Zhang 
3470e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
3480e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
3490e3ece09SJunchao Zhang     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
350076ba34aSJunchao Zhang   }
351076ba34aSJunchao Zhang 
3520e3ece09SJunchao Zhang   // Pull out (unique) globals in the hash table and put them in garray2[]
3530e3ece09SJunchao Zhang   n2 = tot;
3540e3ece09SJunchao Zhang   PetscCall(PetscMalloc1(n2, &garray2));
3550e3ece09SJunchao Zhang   tot = 0;
3560e3ece09SJunchao Zhang   PetscHashIterBegin(g2l, iter);
3570e3ece09SJunchao Zhang   while (!PetscHashIterAtEnd(g2l, iter)) {
3580e3ece09SJunchao Zhang     PetscHashIterGetKey(g2l, iter, key);
3590e3ece09SJunchao Zhang     PetscHashIterNext(g2l, iter);
3600e3ece09SJunchao Zhang     garray2[tot++] = key;
361076ba34aSJunchao Zhang   }
362076ba34aSJunchao Zhang 
3630e3ece09SJunchao Zhang   // Sort garray2[] and then map them to local indices starting from 0
3640e3ece09SJunchao Zhang   PetscCall(PetscSortInt(n2, garray2));
3650e3ece09SJunchao Zhang   PetscCall(PetscHMapIClear(g2l));
3660e3ece09SJunchao Zhang   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
367f0e6e2d1SJunchao Zhang 
3680e3ece09SJunchao Zhang   // Rewrite indices[] with local indices
369f0e6e2d1SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
3700e3ece09SJunchao Zhang     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
3710e3ece09SJunchao Zhang     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
3720e3ece09SJunchao Zhang     indices[i] = val;
3730e3ece09SJunchao Zhang   }
3740e3ece09SJunchao Zhang   // Record the map that maps garray1[i] to garray2[map[i]]
3750e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
3760e3ece09SJunchao Zhang   PetscCall(PetscHMapIDestroy(&g2l));
3770e3ece09SJunchao Zhang   *n2_      = n2;
3780e3ece09SJunchao Zhang   *garray2_ = garray2;
3790e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3800e3ece09SJunchao Zhang }
381f0e6e2d1SJunchao Zhang 
3820e3ece09SJunchao Zhang /*
3830e3ece09SJunchao Zhang   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
3840e3ece09SJunchao Zhang 
3850e3ece09SJunchao Zhang   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
3860e3ece09SJunchao Zhang 
3870e3ece09SJunchao Zhang   Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
3880e3ece09SJunchao Zhang   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
3890e3ece09SJunchao Zhang 
3900e3ece09SJunchao Zhang   Input Parameters:
3910e3ece09SJunchao Zhang +  comm       - MPI communicator of E
3920e3ece09SJunchao Zhang .  A          - diag block of E, using local column indices
3930e3ece09SJunchao Zhang .  B          - off-diag block of E, using local column indices
3940e3ece09SJunchao Zhang .  cstart      - (global) start column of Ed
3950e3ece09SJunchao Zhang .  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
3960e3ece09SJunchao Zhang .  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
3970e3ece09SJunchao Zhang .  ownerSF     - the SF specifies ownership (root) of rows in E
3980e3ece09SJunchao Zhang .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
3990e3ece09SJunchao Zhang -  mm          - to stash intermediate data structures for reuse
4000e3ece09SJunchao Zhang 
4010e3ece09SJunchao Zhang   Output Parameters:
4020e3ece09SJunchao Zhang +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
4030e3ece09SJunchao Zhang -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.
4040e3ece09SJunchao Zhang 
4050e3ece09SJunchao Zhang   Notes:
4060e3ece09SJunchao Zhang   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
4070e3ece09SJunchao Zhang 
4080e3ece09SJunchao Zhang  */
4090e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
4100e3ece09SJunchao Zhang {
4110e3ece09SJunchao Zhang   PetscFunctionBegin;
4120e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
4130e3ece09SJunchao Zhang     PetscInt Em = A.numRows(), Fm;
4140e3ece09SJunchao Zhang     PetscInt n1 = B.numCols();
4150e3ece09SJunchao Zhang 
4160e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
4170e3ece09SJunchao Zhang 
4180e3ece09SJunchao Zhang     // Do the analysis on host
4190e3ece09SJunchao Zhang     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
4200e3ece09SJunchao Zhang     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
4210e3ece09SJunchao Zhang     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
4220e3ece09SJunchao Zhang     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
4230e3ece09SJunchao Zhang     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
4240e3ece09SJunchao Zhang     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
4250e3ece09SJunchao Zhang 
4260e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
427*7b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
4280e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
4290e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
4300e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
4310e3ece09SJunchao Zhang       PetscInt        count, step;
4320e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
4330e3ece09SJunchao Zhang       first = Bj + Bi[i];
4340e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
435f0e6e2d1SJunchao Zhang       count = last - first;
436f0e6e2d1SJunchao Zhang       while (count > 0) {
437f0e6e2d1SJunchao Zhang         it   = first;
438f0e6e2d1SJunchao Zhang         step = count / 2;
439f0e6e2d1SJunchao Zhang         it += step;
4400e3ece09SJunchao Zhang         if (garray1[*it] < cstart) { // map local to global
441f0e6e2d1SJunchao Zhang           first = ++it;
442f0e6e2d1SJunchao Zhang           count -= step + 1;
443f0e6e2d1SJunchao Zhang         } else count = step;
444f0e6e2d1SJunchao Zhang       }
4450e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
4460e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
447f0e6e2d1SJunchao Zhang     }
448f0e6e2d1SJunchao Zhang 
4490e3ece09SJunchao Zhang     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
4500e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
4510e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
4520e3ece09SJunchao Zhang     PetscInt           niranks, nranks;
4530e3ece09SJunchao Zhang     MPI_Request       *reqs;
4540e3ece09SJunchao Zhang     PetscMPIInt        tag;
4550e3ece09SJunchao Zhang     PetscSF            reduceSF;
4560e3ece09SJunchao Zhang     PetscInt          *sdisp, *rdisp;
457f0e6e2d1SJunchao Zhang 
4580e3ece09SJunchao Zhang     PetscCall(PetscCommGetNewTag(comm, &tag));
4590e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
4600e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
461f0e6e2d1SJunchao Zhang 
4620e3ece09SJunchao Zhang     // Find out length of each row I will receive. Even for the same row index, when they are from
4630e3ece09SJunchao Zhang     // different senders, they might have different lengths (and sparsity patterns)
4640e3ece09SJunchao Zhang     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
4650e3ece09SJunchao Zhang     PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process
466f0e6e2d1SJunchao Zhang 
4670e3ece09SJunchao Zhang     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
4680e3ece09SJunchao Zhang 
4690e3ece09SJunchao Zhang     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
4700e3ece09SJunchao Zhang     recvRowLen[0] = 0; // since we will make it in CSR format later
4710e3ece09SJunchao Zhang     recvRowLen++;      // advance the pointer now
4720e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4730e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4740e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4750e3ece09SJunchao Zhang 
4760e3ece09SJunchao Zhang     // Build the real PetscSF for reducing E rows (buffer to buffer)
4770e3ece09SJunchao Zhang     rdisp[0] = 0;
4780e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
4790e3ece09SJunchao Zhang       rdisp[i + 1] = rdisp[i];
4800e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
4810e3ece09SJunchao Zhang     }
4820e3ece09SJunchao Zhang     recvRowLen--; // put it back into csr format
4830e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
4840e3ece09SJunchao Zhang 
4850e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); }
4860e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); }
4870e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
4880e3ece09SJunchao Zhang 
4890e3ece09SJunchao Zhang     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
4900e3ece09SJunchao Zhang     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
4910e3ece09SJunchao Zhang     PetscSFNode *iremote;
4920e3ece09SJunchao Zhang 
4930e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
4940e3ece09SJunchao Zhang     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
4950e3ece09SJunchao Zhang     PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
4960e3ece09SJunchao Zhang 
4970e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) {
4980e3ece09SJunchao Zhang       PetscInt count = 0;
4990e3ece09SJunchao Zhang       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
5000e3ece09SJunchao Zhang       for (PetscInt j = 0; j < count; j++) {
5010e3ece09SJunchao Zhang         iremote[nleaves + j].rank  = ranks[i];
5020e3ece09SJunchao Zhang         iremote[nleaves + j].index = sdisp[i] + j;
5030e3ece09SJunchao Zhang       }
5040e3ece09SJunchao Zhang       nleaves += count;
5050e3ece09SJunchao Zhang     }
5060e3ece09SJunchao Zhang     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
5070e3ece09SJunchao Zhang 
5080e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &reduceSF));
5090e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
5100e3ece09SJunchao Zhang 
5110e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
5120e3ece09SJunchao Zhang     PetscInt *sendCol, *recvCol;
5130e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
5140e3ece09SJunchao Zhang     for (PetscInt k = 0; k < roffset[nranks]; k++) {
5150e3ece09SJunchao Zhang       PetscInt  i      = rmine[k]; // row to be copied
5160e3ece09SJunchao Zhang       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
5170e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
5180e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
5190e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
5200e3ece09SJunchao Zhang         if (j < nzLeft) {
5210e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
5220e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
5230e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
5240e3ece09SJunchao Zhang         } else {
5250e3ece09SJunchao Zhang           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
5260e3ece09SJunchao Zhang         }
5270e3ece09SJunchao Zhang       }
5280e3ece09SJunchao Zhang     }
5290e3ece09SJunchao Zhang     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
5300e3ece09SJunchao Zhang     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
5310e3ece09SJunchao Zhang 
5320e3ece09SJunchao Zhang     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
5330e3ece09SJunchao Zhang     PetscInt *recvRowPerm, *recvColSorted;
5340e3ece09SJunchao Zhang     PetscInt *recvNzPerm, *recvNzPermSorted;
5350e3ece09SJunchao Zhang     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
5360e3ece09SJunchao Zhang 
5370e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
5380e3ece09SJunchao Zhang     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
5390e3ece09SJunchao Zhang     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
5400e3ece09SJunchao Zhang 
5410e3ece09SJunchao Zhang     // i[] array, nz are always easiest to compute
542*7b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
5430e3ece09SJunchao Zhang     MatRowMapType          *Fdi, *Foi;
5440e3ece09SJunchao Zhang     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
5450e3ece09SJunchao Zhang     PetscInt                iter;
5460e3ece09SJunchao Zhang 
5470e3ece09SJunchao Zhang     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
5480e3ece09SJunchao Zhang     Kokkos::deep_copy(Foi_h, 0);
5490e3ece09SJunchao Zhang     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
5500e3ece09SJunchao Zhang     Foi  = Foi_h.data() + 1;
5510e3ece09SJunchao Zhang     iter = 0;
5520e3ece09SJunchao Zhang     while (iter < recvRowCnt) { // iter over received rows
5530e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
5540e3ece09SJunchao Zhang       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)
5550e3ece09SJunchao Zhang 
5560e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
5570e3ece09SJunchao Zhang 
5580e3ece09SJunchao Zhang       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
5590e3ece09SJunchao Zhang       PetscInt  nz    = 0; // nz (with dups) in the current row
5600e3ece09SJunchao Zhang       PetscInt *jbuf  = recvColSorted + FnzDups;
5610e3ece09SJunchao Zhang       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
5620e3ece09SJunchao Zhang       PetscInt *jbuf2 = jbuf; // temp pointers
5630e3ece09SJunchao Zhang       PetscInt *pbuf2 = pbuf;
5640e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
5650e3ece09SJunchao Zhang         PetscInt i   = recvRowPerm[iter + d];
5660e3ece09SJunchao Zhang         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
5670e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
5680e3ece09SJunchao Zhang         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
5690e3ece09SJunchao Zhang         jbuf2 += len;
5700e3ece09SJunchao Zhang         pbuf2 += len;
5710e3ece09SJunchao Zhang         nz += len;
5720e3ece09SJunchao Zhang       }
5730e3ece09SJunchao Zhang       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
5740e3ece09SJunchao Zhang 
5750e3ece09SJunchao Zhang       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
5760e3ece09SJunchao Zhang       PetscInt cur = 0;
5770e3ece09SJunchao Zhang       while (cur < nz) {
5780e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
5790e3ece09SJunchao Zhang         PetscInt dups      = 1;
5800e3ece09SJunchao Zhang 
5810e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
5820e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
5830e3ece09SJunchao Zhang           Fdi[curRowIdx]++;
5840e3ece09SJunchao Zhang           FdnzDups += dups;
5850e3ece09SJunchao Zhang         } else {
5860e3ece09SJunchao Zhang           Foi[curRowIdx]++;
5870e3ece09SJunchao Zhang           FonzDups += dups;
5880e3ece09SJunchao Zhang         }
5890e3ece09SJunchao Zhang         cur += dups;
5900e3ece09SJunchao Zhang       }
5910e3ece09SJunchao Zhang 
5920e3ece09SJunchao Zhang       FnzDups += nz;
5930e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
5940e3ece09SJunchao Zhang     }
5950e3ece09SJunchao Zhang 
5960e3ece09SJunchao Zhang     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
5970e3ece09SJunchao Zhang     Foi = Foi_h.data();
5980e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
5990e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
6000e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
6010e3ece09SJunchao Zhang     }
6020e3ece09SJunchao Zhang     Fdnz = Fdi[Fm];
6030e3ece09SJunchao Zhang     Fonz = Foi[Fm];
6040e3ece09SJunchao Zhang     PetscCall(PetscFree2(sendCol, recvCol));
6050e3ece09SJunchao Zhang 
6060e3ece09SJunchao Zhang     // Allocate j, jmap, jperm for Fd and Fo
607*7b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
608*7b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
609*7b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
6100e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
6110e3ece09SJunchao Zhang     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
6120e3ece09SJunchao Zhang     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
6130e3ece09SJunchao Zhang 
6140e3ece09SJunchao Zhang     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
6150e3ece09SJunchao Zhang     Fdjmap[0] = 0;
6160e3ece09SJunchao Zhang     Fojmap[0] = 0;
6170e3ece09SJunchao Zhang     FnzDups   = 0;
6180e3ece09SJunchao Zhang     Fdnz      = 0;
6190e3ece09SJunchao Zhang     Fonz      = 0;
6200e3ece09SJunchao Zhang     iter      = 0; // iter over received rows
6210e3ece09SJunchao Zhang     while (iter < recvRowCnt) {
6220e3ece09SJunchao Zhang       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
6230e3ece09SJunchao Zhang       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
6240e3ece09SJunchao Zhang       PetscInt nz        = 0;                           // nz (with dups) in the current row
6250e3ece09SJunchao Zhang 
6260e3ece09SJunchao Zhang       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
6270e3ece09SJunchao Zhang       for (PetscInt d = 0; d < dupRows; d++) {
6280e3ece09SJunchao Zhang         PetscInt i = recvRowPerm[iter + d];
6290e3ece09SJunchao Zhang         nz += recvRowLen[i + 1] - recvRowLen[i];
6300e3ece09SJunchao Zhang       }
6310e3ece09SJunchao Zhang 
6320e3ece09SJunchao Zhang       PetscInt *jbuf = recvColSorted + FnzDups;
6330e3ece09SJunchao Zhang       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
6340e3ece09SJunchao Zhang       PetscInt cur = 0;
6350e3ece09SJunchao Zhang       while (cur < nz) {
6360e3ece09SJunchao Zhang         PetscInt curColIdx = jbuf[cur];
6370e3ece09SJunchao Zhang         PetscInt dups      = 1;
6380e3ece09SJunchao Zhang 
6390e3ece09SJunchao Zhang         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
6400e3ece09SJunchao Zhang         if (curColIdx >= cstart && curColIdx < cend) {
6410e3ece09SJunchao Zhang           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
6420e3ece09SJunchao Zhang           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
6430e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
6440e3ece09SJunchao Zhang           FdnzDups += dups;
6450e3ece09SJunchao Zhang           Fdnz++;
6460e3ece09SJunchao Zhang         } else {
6470e3ece09SJunchao Zhang           Foj[Fonz]        = curColIdx; // in global
6480e3ece09SJunchao Zhang           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
6490e3ece09SJunchao Zhang           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
6500e3ece09SJunchao Zhang           FonzDups += dups;
6510e3ece09SJunchao Zhang           Fonz++;
6520e3ece09SJunchao Zhang         }
6530e3ece09SJunchao Zhang         cur += dups;
6540e3ece09SJunchao Zhang         FnzDups += dups;
6550e3ece09SJunchao Zhang       }
6560e3ece09SJunchao Zhang       iter += dupRows; // Move to next unique row
6570e3ece09SJunchao Zhang     }
6580e3ece09SJunchao Zhang     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
6590e3ece09SJunchao Zhang     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
6600e3ece09SJunchao Zhang 
6610e3ece09SJunchao Zhang     // Combine global column indices in garray1[] and Foj[]
6620e3ece09SJunchao Zhang     PetscInt n2, *garray2;
6630e3ece09SJunchao Zhang 
6640e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
6650e3ece09SJunchao Zhang     mm->sf       = reduceSF;
666*7b8d4ba6SJunchao Zhang     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
667*7b8d4ba6SJunchao Zhang     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
668aaa8cc7dSPierre Jolivet     mm->garray   = garray2; // give ownership, so no free
6690e3ece09SJunchao Zhang     mm->n        = n2;
6700e3ece09SJunchao Zhang     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
6710e3ece09SJunchao Zhang     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
6720e3ece09SJunchao Zhang     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
6730e3ece09SJunchao Zhang     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
6740e3ece09SJunchao Zhang     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
6750e3ece09SJunchao Zhang 
6760e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
677*7b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
6780e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
6790e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
680*7b8d4ba6SJunchao Zhang     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
6810e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
6820e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
6830e3ece09SJunchao Zhang 
6840e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
6850e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
6860e3ece09SJunchao Zhang 
6870e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E
6880e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
6890e3ece09SJunchao Zhang 
6900e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
6910e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
6920e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
6930e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
6940e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
6950e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
6960e3ece09SJunchao Zhang 
6970e3ece09SJunchao Zhang   // Handy aliases
6980e3ece09SJunchao Zhang   auto       &Aa           = A.values;
6990e3ece09SJunchao Zhang   auto       &Ba           = B.values;
7000e3ece09SJunchao Zhang   const auto &Ai           = A.graph.row_map;
7010e3ece09SJunchao Zhang   const auto &Bi           = B.graph.row_map;
7020e3ece09SJunchao Zhang   const auto &E_NzLeft     = mm->E_NzLeft;
7030e3ece09SJunchao Zhang   auto       &leafBuf      = mm->leafBuf;
7040e3ece09SJunchao Zhang   auto       &rootBuf      = mm->rootBuf;
7050e3ece09SJunchao Zhang   PetscSF     reduceSF     = mm->sf;
7060e3ece09SJunchao Zhang   PetscInt    Em           = A.numRows();
7070e3ece09SJunchao Zhang   PetscInt    teamSize     = mm->E_TeamSize;
7080e3ece09SJunchao Zhang   PetscInt    vectorLength = mm->E_VectorLength;
7090e3ece09SJunchao Zhang   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
7100e3ece09SJunchao Zhang   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;
7110e3ece09SJunchao Zhang 
7120e3ece09SJunchao Zhang   // Copy rows in A/B of E to leafBuf, then pass it to rootBuf
7130e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
7140e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
7150e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
7160e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
7170e3ece09SJunchao Zhang         if (i < Em) {
7180e3ece09SJunchao Zhang           PetscInt disp   = Ai(i) + Bi(i);
7190e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
7200e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
7210e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
7220e3ece09SJunchao Zhang 
7230e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
7240e3ece09SJunchao Zhang             MatScalar &val = leafBuf(disp + j);
7250e3ece09SJunchao Zhang             if (j < nzleft) { // B left
7260e3ece09SJunchao Zhang               val = Ba(Bi(i) + j);
7270e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
7280e3ece09SJunchao Zhang               val = Aa(Ai(i) + j - nzleft);
7290e3ece09SJunchao Zhang             } else { // B right
7300e3ece09SJunchao Zhang               val = Ba(Bi(i) + j - alen);
731f0e6e2d1SJunchao Zhang             }
732f0e6e2d1SJunchao Zhang           });
733f0e6e2d1SJunchao Zhang         }
734f0e6e2d1SJunchao Zhang       });
7350e3ece09SJunchao Zhang     }));
7360e3ece09SJunchao Zhang   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
737f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
738f0e6e2d1SJunchao Zhang }
7390e3ece09SJunchao Zhang 
740aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce.
//
// Completes the split-phase reduction on mm->sf: calls PetscSFReduceEnd() for the leafBuf -> rootBuf transfer
// (the matching PetscSFReduceWithMemTypeBegin() is presumably issued by MatMPIAIJKokkosReduceBegin() -- confirm),
// then combines the values received in mm->rootBuf into the value arrays of mm->Fd and mm->Fo.
//
// For nonzero i of Fd, the entries Fdjperm(k), k in [Fdjmap(i), Fdjmap(i+1)), are the positions in rootBuf whose
// values are summed to give Fd.values(i). Fo is handled the same way through Fojmap/Fojperm.
//
// Only mm is used in this body; comm, A, B, cstart, cend, garray1, ownerSF, reuse and map are not referenced
// (NOTE(review): they look like they are kept so the signature mirrors the Begin routine -- confirm).
//
// Returns PETSC_SUCCESS when the end of the body is reached; the PetscSFReduceEnd() call is checked with PetscCall().
7410e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
7420e3ece09SJunchao Zhang {
7430e3ece09SJunchao Zhang   PetscFunctionBegin;
  // Local aliases for the buffers, value arrays and jmap/jperm plans stored in mm
7440e3ece09SJunchao Zhang   auto       &leafBuf  = mm->leafBuf;
7450e3ece09SJunchao Zhang   auto       &rootBuf  = mm->rootBuf;
7460e3ece09SJunchao Zhang   auto       &Fda      = mm->Fd.values;
7470e3ece09SJunchao Zhang   const auto &Fdjmap   = mm->Fdjmap;
7480e3ece09SJunchao Zhang   const auto &Fdjperm  = mm->Fdjperm;
7490e3ece09SJunchao Zhang   auto        Fdnz     = mm->Fd.nnz();
7500e3ece09SJunchao Zhang   auto       &Foa      = mm->Fo.values;
7510e3ece09SJunchao Zhang   const auto &Fojmap   = mm->Fojmap;
7520e3ece09SJunchao Zhang   const auto &Fojperm  = mm->Fojperm;
7530e3ece09SJunchao Zhang   auto        Fonz     = mm->Fo.nnz();
7540e3ece09SJunchao Zhang   PetscSF     reduceSF = mm->sf;
7550e3ece09SJunchao Zhang 
  // Finish the scalar reduce (MPI_REPLACE) from leafBuf to rootBuf on reduceSF
7560e3ece09SJunchao Zhang   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
7570e3ece09SJunchao Zhang 
7580e3ece09SJunchao Zhang   // Reduce data in rootBuf to Fd and Fo
  // One iteration per nonzero of Fd: add up the rootBuf entries that Fdjperm lists for it, then overwrite Fda(i)
7590e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
7600e3ece09SJunchao Zhang     Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) {
7610e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7620e3ece09SJunchao Zhang       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
7630e3ece09SJunchao Zhang       Fda(i) = sum;
7640e3ece09SJunchao Zhang     }));
7650e3ece09SJunchao Zhang 
  // Same accumulation for Fo, one iteration per nonzero, driven by Fojmap/Fojperm
7660e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
7670e3ece09SJunchao Zhang     Fonz, KOKKOS_LAMBDA(const MatRowMapType i) {
7680e3ece09SJunchao Zhang       PetscScalar sum = 0.0;
7690e3ece09SJunchao Zhang       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
7700e3ece09SJunchao Zhang       Foa(i) = sum;
7710e3ece09SJunchao Zhang     }));
7720e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
7730e3ece09SJunchao Zhang }
7740e3ece09SJunchao Zhang 
7750e3ece09SJunchao Zhang /*
7760e3ece09SJunchao Zhang   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
7770e3ece09SJunchao Zhang 
7780e3ece09SJunchao Zhang   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
7790e3ece09SJunchao Zhang   device and involves various index mapping.
7800e3ece09SJunchao Zhang 
7810e3ece09SJunchao Zhang   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
7820e3ece09SJunchao Zhang   If F's j-th row is connected to a root identified by PetscSFNode (k,i), then we need to bcast the i-th row of E on rank k
7830e3ece09SJunchao Zhang   to the j-th row of F. ownerSF is not an arbitrary SF; instead, it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
7840e3ece09SJunchao Zhang   F has the same column layout as E.
7850e3ece09SJunchao Zhang 
7860e3ece09SJunchao Zhang   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
787aaa8cc7dSPierre Jolivet   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
7880e3ece09SJunchao Zhang   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
7890e3ece09SJunchao Zhang   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
7900e3ece09SJunchao Zhang   column indices in Fo and update Fo with local indices.
7910e3ece09SJunchao Zhang 
7920e3ece09SJunchao Zhang    Input Parameters:
7930e3ece09SJunchao Zhang +   E       - the MPIAIJKOKKOS matrix
7940e3ece09SJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
7950e3ece09SJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
7960e3ece09SJunchao Zhang -   mm      - to stash matproduct intermediate data structures
7970e3ece09SJunchao Zhang 
7980e3ece09SJunchao Zhang     Output Parameters:
7990e3ece09SJunchao Zhang +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
8000e3ece09SJunchao Zhang -   mm      - contains various info, such as garray2[], Fd, Fo, etc.
8010e3ece09SJunchao Zhang 
8020e3ece09SJunchao Zhang     Notes:
8030e3ece09SJunchao Zhang     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
8040e3ece09SJunchao Zhang     The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
8050e3ece09SJunchao Zhang */
8060e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
8070e3ece09SJunchao Zhang {
8080e3ece09SJunchao Zhang   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
8090e3ece09SJunchao Zhang   Mat               A = empi->A, B = empi->B; // diag and off-diag
8100e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8110e3ece09SJunchao Zhang   PetscInt          Em = E->rmap->n; // #local rows
8120e3ece09SJunchao Zhang   MPI_Comm          comm;
8130e3ece09SJunchao Zhang 
8140e3ece09SJunchao Zhang   PetscFunctionBegin;
8150e3ece09SJunchao Zhang   PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm));
8160e3ece09SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
8170e3ece09SJunchao Zhang     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
8180e3ece09SJunchao Zhang     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
8190e3ece09SJunchao Zhang     const PetscInt *garray1 = empi->garray; // its size is n1
8200e3ece09SJunchao Zhang     PetscInt        cstart, cend;
8210e3ece09SJunchao Zhang     PetscSF         bcastSF;
8220e3ece09SJunchao Zhang 
8230e3ece09SJunchao Zhang     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
8240e3ece09SJunchao Zhang 
8250e3ece09SJunchao Zhang     // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend)
826*7b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
8270e3ece09SJunchao Zhang     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
8280e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Em; i++) {
8290e3ece09SJunchao Zhang       const PetscInt *first, *last, *it;
8300e3ece09SJunchao Zhang       PetscInt        count, step;
8310e3ece09SJunchao Zhang       // std::lower_bound(first,last,cstart), but need to use global column indices
8320e3ece09SJunchao Zhang       first = Bj + Bi[i];
8330e3ece09SJunchao Zhang       last  = Bj + Bi[i + 1];
8340e3ece09SJunchao Zhang       count = last - first;
8350e3ece09SJunchao Zhang       while (count > 0) {
8360e3ece09SJunchao Zhang         it   = first;
8370e3ece09SJunchao Zhang         step = count / 2;
8380e3ece09SJunchao Zhang         it += step;
8390e3ece09SJunchao Zhang         if (empi->garray[*it] < cstart) { // map local to global
8400e3ece09SJunchao Zhang           first = ++it;
8410e3ece09SJunchao Zhang           count -= step + 1;
8420e3ece09SJunchao Zhang         } else count = step;
8430e3ece09SJunchao Zhang       }
8440e3ece09SJunchao Zhang       E_NzLeft[i] = first - (Bj + Bi[i]);
8450e3ece09SJunchao Zhang       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
8460e3ece09SJunchao Zhang     }
8470e3ece09SJunchao Zhang 
8480e3ece09SJunchao Zhang     // Compute row pointer Fi of F
8490e3ece09SJunchao Zhang     PetscInt *Fi, Fm, Fnz;
8500e3ece09SJunchao Zhang     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
8510e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(Fm + 1, &Fi));
8520e3ece09SJunchao Zhang     Fi[0] = 0;
8530e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
8540e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
8550e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
8560e3ece09SJunchao Zhang     Fnz = Fi[Fm];
8570e3ece09SJunchao Zhang 
8580e3ece09SJunchao Zhang     // Build the real PetscSF for bcasting E rows (buffer to buffer)
8590e3ece09SJunchao Zhang     const PetscMPIInt *iranks, *ranks;
8600e3ece09SJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
8610e3ece09SJunchao Zhang     PetscInt           niranks, nranks, *sdisp, *rdisp;
8620e3ece09SJunchao Zhang     MPI_Request       *reqs;
8630e3ece09SJunchao Zhang     PetscMPIInt        tag;
8640e3ece09SJunchao Zhang 
8650e3ece09SJunchao Zhang     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
8660e3ece09SJunchao Zhang     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
8670e3ece09SJunchao Zhang     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
8680e3ece09SJunchao Zhang 
8690e3ece09SJunchao Zhang     sdisp[0] = 0; // send displacement
8700e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) {
8710e3ece09SJunchao Zhang       sdisp[i + 1] = sdisp[i];
8720e3ece09SJunchao Zhang       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
8730e3ece09SJunchao Zhang         PetscInt r = irootloc[j]; // row to be sent
8740e3ece09SJunchao Zhang         sdisp[i + 1] += E_RowLen[r];
8750e3ece09SJunchao Zhang       }
8760e3ece09SJunchao Zhang     }
8770e3ece09SJunchao Zhang 
8780e3ece09SJunchao Zhang     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
8790e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
8800e3ece09SJunchao Zhang     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
8810e3ece09SJunchao Zhang     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
8820e3ece09SJunchao Zhang 
8830e3ece09SJunchao Zhang     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
8840e3ece09SJunchao Zhang     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
8850e3ece09SJunchao Zhang     PetscSFNode *iremote;                  // give ownership to bcastSF
8860e3ece09SJunchao Zhang     PetscCall(PetscMalloc1(nleaves, &iremote));
8870e3ece09SJunchao Zhang     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
8880e3ece09SJunchao Zhang       PetscInt k = 0;
8890e3ece09SJunchao Zhang       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
8900e3ece09SJunchao Zhang         iremote[j].rank  = ranks[i];
8910e3ece09SJunchao Zhang         iremote[j].index = rdisp[i] + k; // their root location
8920e3ece09SJunchao Zhang         k++;
8930e3ece09SJunchao Zhang       }
8940e3ece09SJunchao Zhang     }
8950e3ece09SJunchao Zhang     PetscCall(PetscSFCreate(comm, &bcastSF));
8960e3ece09SJunchao Zhang     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
8970e3ece09SJunchao Zhang     PetscCall(PetscFree3(sdisp, rdisp, reqs));
8980e3ece09SJunchao Zhang 
8990e3ece09SJunchao Zhang     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
900*7b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
9010e3ece09SJunchao Zhang     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying
9020e3ece09SJunchao Zhang     rowoffset[0]                     = 0;
903*7b8d4ba6SJunchao Zhang     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
9040e3ece09SJunchao Zhang 
9050e3ece09SJunchao Zhang     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
9060e3ece09SJunchao Zhang     PetscInt *jbuf, *Fj;
9070e3ece09SJunchao Zhang     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
9080e3ece09SJunchao Zhang     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
9090e3ece09SJunchao Zhang       PetscInt  i      = irootloc[k]; // row to be copied
9100e3ece09SJunchao Zhang       PetscInt *buf    = &jbuf[rowoffset[k]];
9110e3ece09SJunchao Zhang       PetscInt  nzLeft = E_NzLeft[i];
9120e3ece09SJunchao Zhang       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
9130e3ece09SJunchao Zhang       for (PetscInt j = 0; j < alen + blen; j++) {
9140e3ece09SJunchao Zhang         if (j < nzLeft) {
9150e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
9160e3ece09SJunchao Zhang         } else if (j < nzLeft + alen) {
9170e3ece09SJunchao Zhang           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
9180e3ece09SJunchao Zhang         } else {
9190e3ece09SJunchao Zhang           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
9200e3ece09SJunchao Zhang         }
9210e3ece09SJunchao Zhang       }
9220e3ece09SJunchao Zhang     }
9230e3ece09SJunchao Zhang     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
9240e3ece09SJunchao Zhang     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
9250e3ece09SJunchao Zhang 
9260e3ece09SJunchao Zhang     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
927*7b8d4ba6SJunchao Zhang     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
928*7b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
9290e3ece09SJunchao Zhang     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
9300e3ece09SJunchao Zhang     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();
9310e3ece09SJunchao Zhang 
9320e3ece09SJunchao Zhang     Fdi[0] = Foi[0] = 0;
9330e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9340e3ece09SJunchao Zhang       PetscInt *first, *last, *lb1, *lb2;
9350e3ece09SJunchao Zhang       // cut the row into: Left, [cstart, cend), Right
9360e3ece09SJunchao Zhang       first       = Fj + Fi[i];
9370e3ece09SJunchao Zhang       last        = Fj + Fi[i + 1];
9380e3ece09SJunchao Zhang       lb1         = std::lower_bound(first, last, cstart);
9390e3ece09SJunchao Zhang       F_NzLeft[i] = lb1 - first;
9400e3ece09SJunchao Zhang       lb2         = std::lower_bound(first, last, cend);
9410e3ece09SJunchao Zhang       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
9420e3ece09SJunchao Zhang       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
9430e3ece09SJunchao Zhang     }
9440e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9450e3ece09SJunchao Zhang       Fdi[i + 1] += Fdi[i];
9460e3ece09SJunchao Zhang       Foi[i + 1] += Foi[i];
9470e3ece09SJunchao Zhang     }
9480e3ece09SJunchao Zhang 
9490e3ece09SJunchao Zhang     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
9500e3ece09SJunchao Zhang     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
951*7b8d4ba6SJunchao Zhang     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
9520e3ece09SJunchao Zhang     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
9530e3ece09SJunchao Zhang 
9540e3ece09SJunchao Zhang     for (PetscInt i = 0; i < Fm; i++) {
9550e3ece09SJunchao Zhang       PetscInt nzLeft = F_NzLeft[i];
9560e3ece09SJunchao Zhang       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
9570e3ece09SJunchao Zhang       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
9580e3ece09SJunchao Zhang         gid = Fj[Fi[i] + j];
9590e3ece09SJunchao Zhang         if (j < nzLeft) { // left, in global
9600e3ece09SJunchao Zhang           Foj[Foi[i] + j] = gid;
9610e3ece09SJunchao Zhang         } else if (j < nzLeft + len) { // diag, in local
9620e3ece09SJunchao Zhang           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
9630e3ece09SJunchao Zhang         } else { // right, in global
9640e3ece09SJunchao Zhang           Foj[Foi[i] + j - len] = gid;
9650e3ece09SJunchao Zhang         }
9660e3ece09SJunchao Zhang       }
9670e3ece09SJunchao Zhang     }
9680e3ece09SJunchao Zhang     PetscCall(PetscFree2(jbuf, Fj));
9690e3ece09SJunchao Zhang     PetscCall(PetscFree(Fi));
9700e3ece09SJunchao Zhang 
9710e3ece09SJunchao Zhang     // Reduce global indices in Foj[] and garray1[] into local ones
9720e3ece09SJunchao Zhang     PetscInt n2, *garray2;
9730e3ece09SJunchao Zhang     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
9740e3ece09SJunchao Zhang 
9750e3ece09SJunchao Zhang     // Record the plans built above, for reuse
9760e3ece09SJunchao Zhang     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
977*7b8d4ba6SJunchao Zhang     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
9780e3ece09SJunchao Zhang     Kokkos::deep_copy(irootloc_h, tmp);
9790e3ece09SJunchao Zhang     mm->sf        = bcastSF;
9800e3ece09SJunchao Zhang     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
9810e3ece09SJunchao Zhang     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
9820e3ece09SJunchao Zhang     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
9830e3ece09SJunchao Zhang     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
984*7b8d4ba6SJunchao Zhang     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
985*7b8d4ba6SJunchao Zhang     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
9860e3ece09SJunchao Zhang     mm->garray    = garray2;
9870e3ece09SJunchao Zhang     mm->n         = n2;
9880e3ece09SJunchao Zhang 
9890e3ece09SJunchao Zhang     // Output Fd and Fo in KokkosCsrMatrix format
990*7b8d4ba6SJunchao Zhang     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
9910e3ece09SJunchao Zhang     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
9920e3ece09SJunchao Zhang     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
9930e3ece09SJunchao Zhang     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
9940e3ece09SJunchao Zhang     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
9950e3ece09SJunchao Zhang 
9960e3ece09SJunchao Zhang     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
9970e3ece09SJunchao Zhang     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
9980e3ece09SJunchao Zhang 
9990e3ece09SJunchao Zhang     // Compute kernel launch parameters in merging E or splitting F
10000e3ece09SJunchao Zhang     PetscInt teamSize, vectorLength, rowsPerTeam;
10010e3ece09SJunchao Zhang 
10020e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10030e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
10040e3ece09SJunchao Zhang     mm->E_TeamSize     = teamSize;
10050e3ece09SJunchao Zhang     mm->E_VectorLength = vectorLength;
10060e3ece09SJunchao Zhang     mm->E_RowsPerTeam  = rowsPerTeam;
10070e3ece09SJunchao Zhang 
10080e3ece09SJunchao Zhang     teamSize = vectorLength = rowsPerTeam = -1;
10090e3ece09SJunchao Zhang     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
10100e3ece09SJunchao Zhang     mm->F_TeamSize     = teamSize;
10110e3ece09SJunchao Zhang     mm->F_VectorLength = vectorLength;
10120e3ece09SJunchao Zhang     mm->F_RowsPerTeam  = rowsPerTeam;
10130e3ece09SJunchao Zhang   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
10140e3ece09SJunchao Zhang 
10150e3ece09SJunchao Zhang   // Sync E's value to device
10160e3ece09SJunchao Zhang   akok->a_dual.sync_device();
10170e3ece09SJunchao Zhang   bkok->a_dual.sync_device();
10180e3ece09SJunchao Zhang 
10190e3ece09SJunchao Zhang   // Handy aliases
10200e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
10210e3ece09SJunchao Zhang   const auto &Ba = bkok->a_dual.view_device();
10220e3ece09SJunchao Zhang   const auto &Ai = akok->i_dual.view_device();
10230e3ece09SJunchao Zhang   const auto &Bi = bkok->i_dual.view_device();
10240e3ece09SJunchao Zhang 
10250e3ece09SJunchao Zhang   // Fetch the plans
10260e3ece09SJunchao Zhang   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
10270e3ece09SJunchao Zhang   PetscSF             &bcastSF   = mm->sf;
10280e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf   = mm->rootBuf;
10290e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf   = mm->leafBuf;
10300e3ece09SJunchao Zhang   PetscIntKokkosView  &irootloc  = mm->irootloc;
10310e3ece09SJunchao Zhang   PetscIntKokkosView  &rowoffset = mm->rowoffset;
10320e3ece09SJunchao Zhang 
10330e3ece09SJunchao Zhang   PetscInt teamSize     = mm->E_TeamSize;
10340e3ece09SJunchao Zhang   PetscInt vectorLength = mm->E_VectorLength;
10350e3ece09SJunchao Zhang   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
10360e3ece09SJunchao Zhang   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
10370e3ece09SJunchao Zhang 
10380e3ece09SJunchao Zhang   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
10390e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
10400e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10410e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10420e3ece09SJunchao Zhang         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
10430e3ece09SJunchao Zhang         if (r < irootloc.extent(0)) {
10440e3ece09SJunchao Zhang           PetscInt i      = irootloc(r); // row i of E
10450e3ece09SJunchao Zhang           PetscInt disp   = rowoffset(r);
10460e3ece09SJunchao Zhang           PetscInt alen   = Ai(i + 1) - Ai(i);
10470e3ece09SJunchao Zhang           PetscInt blen   = Bi(i + 1) - Bi(i);
10480e3ece09SJunchao Zhang           PetscInt nzleft = E_NzLeft(i);
10490e3ece09SJunchao Zhang 
10500e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
10510e3ece09SJunchao Zhang             if (j < nzleft) { // B left
10520e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j);
10530e3ece09SJunchao Zhang             } else if (j < nzleft + alen) { // diag A
10540e3ece09SJunchao Zhang               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
10550e3ece09SJunchao Zhang             } else { // B right
10560e3ece09SJunchao Zhang               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
10570e3ece09SJunchao Zhang             }
10580e3ece09SJunchao Zhang           });
10590e3ece09SJunchao Zhang         }
10600e3ece09SJunchao Zhang       });
10610e3ece09SJunchao Zhang     }));
10620e3ece09SJunchao Zhang   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
10630e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
10640e3ece09SJunchao Zhang }
10650e3ece09SJunchao Zhang 
10660e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast.
10670e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
10680e3ece09SJunchao Zhang {
10690e3ece09SJunchao Zhang   PetscFunctionBegin;
10700e3ece09SJunchao Zhang   const auto &Fd  = mm->Fd;
10710e3ece09SJunchao Zhang   const auto &Fo  = mm->Fo;
10720e3ece09SJunchao Zhang   const auto &Fdi = Fd.graph.row_map;
10730e3ece09SJunchao Zhang   const auto &Foi = Fo.graph.row_map;
10740e3ece09SJunchao Zhang   auto       &Fda = Fd.values;
10750e3ece09SJunchao Zhang   auto       &Foa = Fo.values;
10760e3ece09SJunchao Zhang   auto        Fm  = Fd.numRows();
10770e3ece09SJunchao Zhang 
10780e3ece09SJunchao Zhang   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
10790e3ece09SJunchao Zhang   PetscSF             &bcastSF      = mm->sf;
10800e3ece09SJunchao Zhang   MatScalarKokkosView &rootBuf      = mm->rootBuf;
10810e3ece09SJunchao Zhang   MatScalarKokkosView &leafBuf      = mm->leafBuf;
10820e3ece09SJunchao Zhang   PetscInt             teamSize     = mm->F_TeamSize;
10830e3ece09SJunchao Zhang   PetscInt             vectorLength = mm->F_VectorLength;
10840e3ece09SJunchao Zhang   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
10850e3ece09SJunchao Zhang   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;
10860e3ece09SJunchao Zhang 
10870e3ece09SJunchao Zhang   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
10880e3ece09SJunchao Zhang 
10890e3ece09SJunchao Zhang   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
10900e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
10910e3ece09SJunchao Zhang     Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
10920e3ece09SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
10930e3ece09SJunchao Zhang         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
10940e3ece09SJunchao Zhang         if (i < Fm) {
10950e3ece09SJunchao Zhang           PetscInt nzLeft = F_NzLeft(i);
10960e3ece09SJunchao Zhang           PetscInt alen   = Fdi(i + 1) - Fdi(i);
10970e3ece09SJunchao Zhang           PetscInt blen   = Foi(i + 1) - Foi(i);
10980e3ece09SJunchao Zhang           PetscInt Fii    = Fdi(i) + Foi(i);
10990e3ece09SJunchao Zhang 
11000e3ece09SJunchao Zhang           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
11010e3ece09SJunchao Zhang             PetscScalar val = leafBuf(Fii + j);
11020e3ece09SJunchao Zhang             if (j < nzLeft) { // left
11030e3ece09SJunchao Zhang               Foa(Foi(i) + j) = val;
11040e3ece09SJunchao Zhang             } else if (j < nzLeft + alen) { // diag
11050e3ece09SJunchao Zhang               Fda(Fdi(i) + j - nzLeft) = val;
11060e3ece09SJunchao Zhang             } else { // right
11070e3ece09SJunchao Zhang               Foa(Foi(i) + j - alen) = val;
11080e3ece09SJunchao Zhang             }
11090e3ece09SJunchao Zhang           });
11100e3ece09SJunchao Zhang         }
11110e3ece09SJunchao Zhang       });
11120e3ece09SJunchao Zhang     }));
11130e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11140e3ece09SJunchao Zhang }
11150e3ece09SJunchao Zhang 
11160e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11170e3ece09SJunchao Zhang {
11180e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11190e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11200e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
11210e3ece09SJunchao Zhang   PetscInt        cstart, cend;
11220e3ece09SJunchao Zhang   MPI_Comm        comm;
11230e3ece09SJunchao Zhang 
11240e3ece09SJunchao Zhang   PetscFunctionBegin;
11250e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
11260e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
11270e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
11280e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
11290e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
11300e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
11310e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
11320e3ece09SJunchao Zhang 
11330e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
11340e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
11350e3ece09SJunchao Zhang 
11360e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
11370e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
11380e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
11390e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1140f0e6e2d1SJunchao Zhang   #endif
11410e3ece09SJunchao Zhang #endif
11420e3ece09SJunchao Zhang 
11430e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
11440e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
11450e3ece09SJunchao Zhang   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
11460e3ece09SJunchao Zhang   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
11470e3ece09SJunchao Zhang 
11480e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
11490e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
11500e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
11510e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
11520e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
11530e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
11540e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
11550e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11560e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
11570e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
11580e3ece09SJunchao Zhang #endif
11590e3ece09SJunchao Zhang 
11600e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1161*7b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
11620e3ece09SJunchao Zhang   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
11630e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11640e3ece09SJunchao Zhang 
11650e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
11660e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
11670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11680e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
11690e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
11700e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
11710e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
11720e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
11730e3ece09SJunchao Zhang #endif
11740e3ece09SJunchao Zhang 
11750e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
11760e3ece09SJunchao Zhang 
11770e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1178*7b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
11790e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
11800e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
11810e3ece09SJunchao Zhang     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
11820e3ece09SJunchao Zhang   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
11830e3ece09SJunchao Zhang 
11840e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
11850e3ece09SJunchao Zhang   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
11860e3ece09SJunchao Zhang   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
11870e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
11880e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
11890e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
11900e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
11910e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
11920e3ece09SJunchao Zhang }
11930e3ece09SJunchao Zhang 
11940e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
11950e3ece09SJunchao Zhang {
11960e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
11970e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
11980e3ece09SJunchao Zhang   KokkosCsrMatrix Adt, Aot, Bd, Bo;
11990e3ece09SJunchao Zhang   MPI_Comm        comm;
12000e3ece09SJunchao Zhang 
12010e3ece09SJunchao Zhang   PetscFunctionBegin;
12020e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
12030e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
12040e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
12050e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12060e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12070e3ece09SJunchao Zhang 
12080e3ece09SJunchao Zhang   // Aot * (B's diag + B's off-diag)
12090e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
12100e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
12110e3ece09SJunchao Zhang 
12120e3ece09SJunchao Zhang   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
12130e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12140e3ece09SJunchao Zhang 
12150e3ece09SJunchao Zhang   // Adt * (B's diag + B's off-diag)
12160e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
12170e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
12180e3ece09SJunchao Zhang 
12190e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
12200e3ece09SJunchao Zhang 
12210e3ece09SJunchao Zhang   // C = (C1+Fd, C2+Fo)
12220e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
12230e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
12240e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12250e3ece09SJunchao Zhang }
1226f0e6e2d1SJunchao Zhang 
1227076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1228076ba34aSJunchao Zhang 
1229076ba34aSJunchao Zhang   Input Parameters:
1230076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1231076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1232076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1233076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1234076ba34aSJunchao Zhang */
1235d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1236d71ae5a4SJacob Faibussowitsch {
12370e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
12380e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
12390e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1240076ba34aSJunchao Zhang 
1241076ba34aSJunchao Zhang   PetscFunctionBegin;
12420e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
12430e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
12440e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
12450e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
12460e3ece09SJunchao Zhang 
12470e3ece09SJunchao Zhang   // TODO: add command line options to select spgemm algorithms
12480e3ece09SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
12490e3ece09SJunchao Zhang 
12500e3ece09SJunchao Zhang   // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
12510e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
12520e3ece09SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
12530e3ece09SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
12540e3ece09SJunchao Zhang   #endif
1255f0e6e2d1SJunchao Zhang #endif
1256f0e6e2d1SJunchao Zhang 
12570e3ece09SJunchao Zhang   mm->kh1.create_spgemm_handle(spgemm_alg);
12580e3ece09SJunchao Zhang   mm->kh2.create_spgemm_handle(spgemm_alg);
12590e3ece09SJunchao Zhang   mm->kh3.create_spgemm_handle(spgemm_alg);
12600e3ece09SJunchao Zhang   mm->kh4.create_spgemm_handle(spgemm_alg);
1261076ba34aSJunchao Zhang 
12620e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
1263*7b8d4ba6SJunchao Zhang   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
12640e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1265076ba34aSJunchao Zhang 
12660e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
12670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
12680e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
12690e3ece09SJunchao Zhang   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
12700e3ece09SJunchao Zhang   // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
12710e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
12720e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
12730e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12740e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C1));
12750e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
12760e3ece09SJunchao Zhang #endif
1277076ba34aSJunchao Zhang 
12780e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1279076ba34aSJunchao Zhang 
12800e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
12810e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12820e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
12830e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12840e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
12850e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
12860e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C3));
12870e3ece09SJunchao Zhang   PetscCallCXX(sort_crs_matrix(mm->C4));
12880e3ece09SJunchao Zhang #endif
1289076ba34aSJunchao Zhang 
12900e3ece09SJunchao Zhang   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1291*7b8d4ba6SJunchao Zhang   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
12920e3ece09SJunchao Zhang   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
12930e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
12940e3ece09SJunchao Zhang     oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
12950e3ece09SJunchao Zhang   mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
12960e3ece09SJunchao Zhang 
12970e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
12980e3ece09SJunchao Zhang   mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
12990e3ece09SJunchao Zhang   mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
13000e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
13010e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
13020e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13030e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1305076ba34aSJunchao Zhang }
1306076ba34aSJunchao Zhang 
13070e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1308d71ae5a4SJacob Faibussowitsch {
13090e3ece09SJunchao Zhang   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
13100e3ece09SJunchao Zhang   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
13110e3ece09SJunchao Zhang   KokkosCsrMatrix Ad, Ao, Bd, Bo;
1312076ba34aSJunchao Zhang 
1313076ba34aSJunchao Zhang   PetscFunctionBegin;
13140e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
13150e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
13160e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
13170e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1318076ba34aSJunchao Zhang 
13190e3ece09SJunchao Zhang   // Bcast B's rows to form F, and overlap the communication
13200e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1321076ba34aSJunchao Zhang 
13220e3ece09SJunchao Zhang   // A's diag * (B's diag + B's off-diag)
13230e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
13240e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1325076ba34aSJunchao Zhang 
13260e3ece09SJunchao Zhang   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1327076ba34aSJunchao Zhang 
13280e3ece09SJunchao Zhang   // A's off-diag * (F's diag + F's off-diag)
13290e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
13300e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
13310e3ece09SJunchao Zhang 
13320e3ece09SJunchao Zhang   // C = (Cd, Co) = (C1+C3, C2+C4)
13330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
13340e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
13353ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1336076ba34aSJunchao Zhang }
1337076ba34aSJunchao Zhang 
1338d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1339d71ae5a4SJacob Faibussowitsch {
13400e3ece09SJunchao Zhang   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
13410e3ece09SJunchao Zhang   Mat_Product                 *product;
13420e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1343076ba34aSJunchao Zhang   MatProductType               ptype;
13440e3ece09SJunchao Zhang   Mat                          A, B;
1345076ba34aSJunchao Zhang 
1346076ba34aSJunchao Zhang   PetscFunctionBegin;
13470e3ece09SJunchao Zhang   MatCheckProduct(C, 1); // make sure C is a product
13480e3ece09SJunchao Zhang   product = C->product;
13490e3ece09SJunchao Zhang   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1350076ba34aSJunchao Zhang   ptype   = product->type;
1351076ba34aSJunchao Zhang   A       = product->A;
1352076ba34aSJunchao Zhang   B       = product->B;
1353076ba34aSJunchao Zhang 
13540e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
13550e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
13560e3ece09SJunchao Zhang   // we still do numeric.
13570e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
13580e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
13593ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1360076ba34aSJunchao Zhang   }
1361076ba34aSJunchao Zhang 
1362076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
13630e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1364076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
13650e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
13660e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
13670e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
13680e3ece09SJunchao Zhang     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1369076ba34aSJunchao Zhang   }
13700e3ece09SJunchao Zhang 
13710e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
13720e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
13733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1374076ba34aSJunchao Zhang }
1375076ba34aSJunchao Zhang 
1376d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1377d71ae5a4SJacob Faibussowitsch {
1378076ba34aSJunchao Zhang   Mat                          A, B;
13790e3ece09SJunchao Zhang   Mat_Product                 *product;
1380076ba34aSJunchao Zhang   MatProductType               ptype;
13810e3ece09SJunchao Zhang   MatProductData_MPIAIJKokkos *pdata;
1382076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
13830e3ece09SJunchao Zhang   PetscInt                     m, n, M, N;
13840e3ece09SJunchao Zhang   Mat                          Cd, Co;
13850e3ece09SJunchao Zhang   MPI_Comm                     comm;
1386076ba34aSJunchao Zhang 
1387076ba34aSJunchao Zhang   PetscFunctionBegin;
13880e3ece09SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1389076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
13900e3ece09SJunchao Zhang   product = C->product;
13910e3ece09SJunchao Zhang   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1392076ba34aSJunchao Zhang   ptype = product->type;
1393076ba34aSJunchao Zhang   A     = product->A;
1394076ba34aSJunchao Zhang   B     = product->B;
1395076ba34aSJunchao Zhang 
1396076ba34aSJunchao Zhang   switch (ptype) {
13979371c9d4SSatish Balay   case MATPRODUCT_AB:
13989371c9d4SSatish Balay     m = A->rmap->n;
13999371c9d4SSatish Balay     n = B->cmap->n;
14009371c9d4SSatish Balay     M = A->rmap->N;
14019371c9d4SSatish Balay     N = B->cmap->N;
14029371c9d4SSatish Balay     break;
14039371c9d4SSatish Balay   case MATPRODUCT_AtB:
14049371c9d4SSatish Balay     m = A->cmap->n;
14059371c9d4SSatish Balay     n = B->cmap->n;
14069371c9d4SSatish Balay     M = A->cmap->N;
14079371c9d4SSatish Balay     N = B->cmap->N;
14089371c9d4SSatish Balay     break;
14099371c9d4SSatish Balay   case MATPRODUCT_PtAP:
14109371c9d4SSatish Balay     m = B->cmap->n;
14119371c9d4SSatish Balay     n = B->cmap->n;
14129371c9d4SSatish Balay     M = B->cmap->N;
14139371c9d4SSatish Balay     N = B->cmap->N;
14149371c9d4SSatish Balay     break; /* BtAB */
1415d71ae5a4SJacob Faibussowitsch   default:
14160e3ece09SJunchao Zhang     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1417076ba34aSJunchao Zhang   }
1418076ba34aSJunchao Zhang 
14199566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
14209566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
14219566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
14229566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1423076ba34aSJunchao Zhang 
14240e3ece09SJunchao Zhang   pdata           = new MatProductData_MPIAIJKokkos();
14250e3ece09SJunchao Zhang   pdata->reusesym = product->api_user;
1426076ba34aSJunchao Zhang 
1427076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
14280e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14290e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
14300e3ece09SJunchao Zhang     mm = pdata->mmAB = mmAB;
1431076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
14320e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14330e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
14340e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
14350e3ece09SJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
14360e3ece09SJunchao Zhang     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z
14370e3ece09SJunchao Zhang 
14380e3ece09SJunchao Zhang     auto mmAB = new MatMatStruct_AB();
14390e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
14400e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
14410e3ece09SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
14420e3ece09SJunchao Zhang     pdata->mmAB = mmAB;
14430e3ece09SJunchao Zhang 
14440e3ece09SJunchao Zhang     m = A->rmap->n; // Z's layout
14450e3ece09SJunchao Zhang     n = B->cmap->n;
14460e3ece09SJunchao Zhang     M = A->rmap->N;
14470e3ece09SJunchao Zhang     N = B->cmap->N;
14480e3ece09SJunchao Zhang     PetscCall(MatCreate(comm, &Z));
14490e3ece09SJunchao Zhang     PetscCall(MatSetSizes(Z, m, n, M, N));
14500e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->rmap));
14510e3ece09SJunchao Zhang     PetscCall(PetscLayoutSetUp(Z->cmap));
14520e3ece09SJunchao Zhang     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
14530e3ece09SJunchao Zhang     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
14540e3ece09SJunchao Zhang 
14550e3ece09SJunchao Zhang     auto mmAtB = new MatMatStruct_AtB();
14560e3ece09SJunchao Zhang     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
14570e3ece09SJunchao Zhang 
14580e3ece09SJunchao Zhang     pdata->Z = Z; // give ownership to pdata
14590e3ece09SJunchao Zhang     mm = pdata->mmAtB = mmAtB;
1460076ba34aSJunchao Zhang   }
14610e3ece09SJunchao Zhang 
14620e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
14630e3ece09SJunchao Zhang   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
14640e3ece09SJunchao Zhang   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
14650e3ece09SJunchao Zhang 
14660e3ece09SJunchao Zhang   C->product->data       = pdata;
1467076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1468076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
14693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1470076ba34aSJunchao Zhang }
1471076ba34aSJunchao Zhang 
1472d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1473d71ae5a4SJacob Faibussowitsch {
1474076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1475076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1476076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1477076ba34aSJunchao Zhang 
1478076ba34aSJunchao Zhang   PetscFunctionBegin;
1479076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
148048a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1481076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1482076ba34aSJunchao Zhang     switch (product->type) {
1483076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1484076ba34aSJunchao Zhang       if (product->api_user) {
1485d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
14869566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1487d0609cedSBarry Smith         PetscOptionsEnd();
1488076ba34aSJunchao Zhang       } else {
1489d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
14909566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1491d0609cedSBarry Smith         PetscOptionsEnd();
1492076ba34aSJunchao Zhang       }
1493076ba34aSJunchao Zhang       break;
1494076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1495076ba34aSJunchao Zhang       if (product->api_user) {
1496d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
14979566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1498d0609cedSBarry Smith         PetscOptionsEnd();
1499076ba34aSJunchao Zhang       } else {
1500d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
15019566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1502d0609cedSBarry Smith         PetscOptionsEnd();
1503076ba34aSJunchao Zhang       }
1504076ba34aSJunchao Zhang       break;
1505076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1506076ba34aSJunchao Zhang       if (product->api_user) {
1507d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
15089566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1509d0609cedSBarry Smith         PetscOptionsEnd();
1510076ba34aSJunchao Zhang       } else {
1511d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
15129566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1513d0609cedSBarry Smith         PetscOptionsEnd();
1514076ba34aSJunchao Zhang       }
1515076ba34aSJunchao Zhang       break;
1516d71ae5a4SJacob Faibussowitsch     default:
1517d71ae5a4SJacob Faibussowitsch       break;
1518076ba34aSJunchao Zhang     }
1519076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1520076ba34aSJunchao Zhang   }
1521076ba34aSJunchao Zhang   if (match) {
1522076ba34aSJunchao Zhang     switch (product->type) {
1523076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1524076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1525d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1526d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1527d71ae5a4SJacob Faibussowitsch       break;
1528d71ae5a4SJacob Faibussowitsch     default:
1529d71ae5a4SJacob Faibussowitsch       break;
1530076ba34aSJunchao Zhang     }
1531076ba34aSJunchao Zhang   }
1532076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
153348a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
15343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1535076ba34aSJunchao Zhang }
1536076ba34aSJunchao Zhang 
1537d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1538d71ae5a4SJacob Faibussowitsch {
1539394ed5ebSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1540cbc6b225SStefano Zampini   Mat_MPIAIJKokkos *mpikok;
154142550becSJunchao Zhang 
154242550becSJunchao Zhang   PetscFunctionBegin;
154330203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1544cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
15459566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
15469566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
15479566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
1548cbc6b225SStefano Zampini   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1549cbc6b225SStefano Zampini   delete mpikok;
1550394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
15513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
155242550becSJunchao Zhang }
155342550becSJunchao Zhang 
1554d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1555d71ae5a4SJacob Faibussowitsch {
1556394ed5ebSJunchao Zhang   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
155742550becSJunchao Zhang   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
155842550becSJunchao Zhang   Mat                         A = mpiaij->A, B = mpiaij->B;
1559158ec288SJunchao Zhang   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
156042550becSJunchao Zhang   MatScalarKokkosView         Aa, Ba;
1561394ed5ebSJunchao Zhang   MatScalarKokkosView         v1;
156242550becSJunchao Zhang   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
156342550becSJunchao Zhang   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1564158ec288SJunchao Zhang   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1565158ec288SJunchao Zhang   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1566394ed5ebSJunchao Zhang   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1567394ed5ebSJunchao Zhang   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
156842550becSJunchao Zhang   PetscMemType                memtype;
156942550becSJunchao Zhang 
157042550becSJunchao Zhang   PetscFunctionBegin;
15719566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
157242550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1573394ed5ebSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
157442550becSJunchao Zhang   } else {
1575394ed5ebSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
157642550becSJunchao Zhang   }
157742550becSJunchao Zhang 
157842550becSJunchao Zhang   if (imode == INSERT_VALUES) {
15799566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
15809566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1581394ed5ebSJunchao Zhang   } else {
15829566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
15839566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
158442550becSJunchao Zhang   }
158542550becSJunchao Zhang 
158642550becSJunchao Zhang   /* Pack entries to be sent to remote */
15879371c9d4SSatish Balay   Kokkos::parallel_for(
15889371c9d4SSatish Balay     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
158942550becSJunchao Zhang 
159042550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
15919566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1592158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
15939371c9d4SSatish Balay   Kokkos::parallel_for(
15949371c9d4SSatish Balay     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1595158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1596158ec288SJunchao Zhang       if (i < Annz) {
1597158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1598ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1599158ec288SJunchao Zhang       } else {
1600158ec288SJunchao Zhang         i -= Annz;
1601158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1602ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1603158ec288SJunchao Zhang       }
1604158ec288SJunchao Zhang     });
16059566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
160642550becSJunchao Zhang 
1607158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
16089371c9d4SSatish Balay   Kokkos::parallel_for(
16099371c9d4SSatish Balay     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1610158ec288SJunchao Zhang       if (i < Annz2) {
1611158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1612158ec288SJunchao Zhang       } else {
1613158ec288SJunchao Zhang         i -= Annz2;
1614158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1615158ec288SJunchao Zhang       }
1616158ec288SJunchao Zhang     });
161742550becSJunchao Zhang 
1618394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
16199566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
16209566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1621394ed5ebSJunchao Zhang   } else {
16229566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
16239566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1624394ed5ebSJunchao Zhang   }
16253ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
162642550becSJunchao Zhang }
162742550becSJunchao Zhang 
1628d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1629d71ae5a4SJacob Faibussowitsch {
163042550becSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1631076ba34aSJunchao Zhang 
1632076ba34aSJunchao Zhang   PetscFunctionBegin;
16339566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
16349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
16359566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
16369566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
163742550becSJunchao Zhang   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
16389566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
16393ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1640076ba34aSJunchao Zhang }
1641076ba34aSJunchao Zhang 
1642d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1643d71ae5a4SJacob Faibussowitsch {
16448c3ff71bSJunchao Zhang   Mat         B;
1645076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
16468c3ff71bSJunchao Zhang 
16478c3ff71bSJunchao Zhang   PetscFunctionBegin;
16488c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
16499566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
16508c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
16519566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
16528c3ff71bSJunchao Zhang   }
16538c3ff71bSJunchao Zhang   B = *newmat;
16548c3ff71bSJunchao Zhang 
16556f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
16569566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
16579566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
16589566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
16598c3ff71bSJunchao Zhang 
1660076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
16619566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
16629566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
16639566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1664076ba34aSJunchao Zhang 
16658c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
16668c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
16678c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
16688c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1669076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1670076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
16718c3ff71bSJunchao Zhang 
16729566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
16739566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
16749566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
16759566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
16763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16778c3ff71bSJunchao Zhang }
16783f3ba80aSJunchao Zhang /*MC
167911a5261eSBarry Smith    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
16808c3ff71bSJunchao Zhang 
16813f3ba80aSJunchao Zhang    A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
16823f3ba80aSJunchao Zhang 
16832ef1f0ffSBarry Smith    Options Database Key:
16842ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
16853f3ba80aSJunchao Zhang 
16863f3ba80aSJunchao Zhang   Level: beginner
16873f3ba80aSJunchao Zhang 
16882ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
16893f3ba80aSJunchao Zhang M*/
1690d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1691d71ae5a4SJacob Faibussowitsch {
16928c3ff71bSJunchao Zhang   PetscFunctionBegin;
16939566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16949566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
16959566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
16963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16978c3ff71bSJunchao Zhang }
16988c3ff71bSJunchao Zhang 
16998c3ff71bSJunchao Zhang /*@C
170011a5261eSBarry Smith    MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format
17018c3ff71bSJunchao Zhang    (the default parallel PETSc format).  This matrix will ultimately pushed down
170220f4b53cSBarry Smith    to Kokkos for calculations.
17038c3ff71bSJunchao Zhang 
17048c3ff71bSJunchao Zhang    Collective
17058c3ff71bSJunchao Zhang 
17068c3ff71bSJunchao Zhang    Input Parameters:
170711a5261eSBarry Smith +  comm - MPI communicator, set to `PETSC_COMM_SELF`
170820f4b53cSBarry Smith .  m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
170920f4b53cSBarry Smith            This value should be the same as the local size used in creating the
171020f4b53cSBarry Smith            y vector for the matrix-vector product y = Ax.
171120f4b53cSBarry Smith .  n - This value should be the same as the local size used in creating the
171220f4b53cSBarry Smith        x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have
171320f4b53cSBarry Smith        calculated if N is given) For square matrices n is almost always `m`.
171420f4b53cSBarry Smith .  M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
171520f4b53cSBarry Smith .  N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
171620f4b53cSBarry Smith .  d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
171720f4b53cSBarry Smith            (same value is used for all local rows)
171820f4b53cSBarry Smith .  d_nnz - array containing the number of nonzeros in the various rows of the
171920f4b53cSBarry Smith            DIAGONAL portion of the local submatrix (possibly different for each row)
172020f4b53cSBarry Smith            or `NULL`, if `d_nz` is used to specify the nonzero structure.
172120f4b53cSBarry Smith            The size of this array is equal to the number of local rows, i.e `m`.
172220f4b53cSBarry Smith            For matrices you plan to factor you must leave room for the diagonal entry and
172320f4b53cSBarry Smith            put in the entry even if it is zero.
172420f4b53cSBarry Smith .  o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
172520f4b53cSBarry Smith            submatrix (same value is used for all local rows).
172620f4b53cSBarry Smith -  o_nnz - array containing the number of nonzeros in the various rows of the
172720f4b53cSBarry Smith            OFF-DIAGONAL portion of the local submatrix (possibly different for
172820f4b53cSBarry Smith            each row) or `NULL`, if `o_nz` is used to specify the nonzero
172920f4b53cSBarry Smith            structure. The size of this array is equal to the number
173020f4b53cSBarry Smith            of local rows, i.e `m`.
17318c3ff71bSJunchao Zhang 
17328c3ff71bSJunchao Zhang    Output Parameter:
17338c3ff71bSJunchao Zhang .  A - the matrix
17348c3ff71bSJunchao Zhang 
17352ef1f0ffSBarry Smith    Level: intermediate
17362ef1f0ffSBarry Smith 
17372ef1f0ffSBarry Smith    Notes:
173811a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
17398c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
174011a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
17418c3ff71bSJunchao Zhang 
1742667f096bSBarry Smith    The AIJ format, also called compressed row storage), is fully compatible with standard Fortran
17438c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
17442ef1f0ffSBarry Smith    either one (as in Fortran) or zero.
17458c3ff71bSJunchao Zhang 
17462ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
17472ef1f0ffSBarry Smith           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS`
17488c3ff71bSJunchao Zhang @*/
1749d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1750d71ae5a4SJacob Faibussowitsch {
17518c3ff71bSJunchao Zhang   PetscMPIInt size;
17528c3ff71bSJunchao Zhang 
17538c3ff71bSJunchao Zhang   PetscFunctionBegin;
17549566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
17559566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
17569566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
17578c3ff71bSJunchao Zhang   if (size > 1) {
17589566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
17599566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
17608c3ff71bSJunchao Zhang   } else {
17619566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
17629566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
17638c3ff71bSJunchao Zhang   }
17643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17658c3ff71bSJunchao Zhang }
1766