xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 2ef1f0ff6e3530e8731eb06ad663081f5844f49f)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2f0e6e2d1SJunchao Zhang #include <petscpkg_version.h>
3076ba34aSJunchao Zhang #include <petscsf.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
642550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
811d22bbfSJunchao Zhang 
9d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
10d71ae5a4SJacob Faibussowitsch {
115519a089SJose E. Roman   Mat_SeqAIJKokkos *aijkok;
1230203840SJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
138c3ff71bSJunchao Zhang 
148c3ff71bSJunchao Zhang   PetscFunctionBegin;
159566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1630203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1730203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1830203840SJunchao Zhang    */
1930203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
2030203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2130203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2230203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2330203840SJunchao Zhang   }
245519a089SJose E. Roman   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
25a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
26a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
27a587d139SMark   }
28a587d139SMark 
293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
308c3ff71bSJunchao Zhang }
318c3ff71bSJunchao Zhang 
32d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33d71ae5a4SJacob Faibussowitsch {
348c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
358c3ff71bSJunchao Zhang 
368c3ff71bSJunchao Zhang   PetscFunctionBegin;
379566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
389566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
396a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
408c3ff71bSJunchao Zhang   if (d_nnz) {
416a29ce69SStefano Zampini     PetscInt i;
42ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
438c3ff71bSJunchao Zhang   }
448c3ff71bSJunchao Zhang   if (o_nnz) {
456a29ce69SStefano Zampini     PetscInt i;
46ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
478c3ff71bSJunchao Zhang   }
486a29ce69SStefano Zampini #endif
496a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
50eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
516a29ce69SStefano Zampini #else
529566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
536a29ce69SStefano Zampini #endif
549566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
559566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
569566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
576a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
589566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
596a29ce69SStefano Zampini 
606a29ce69SStefano Zampini   if (!mpiaij->A) {
619566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
629566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
636a29ce69SStefano Zampini   }
646a29ce69SStefano Zampini   if (!mpiaij->B) {
656a29ce69SStefano Zampini     PetscMPIInt size;
669566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
679566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
689566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
698c3ff71bSJunchao Zhang   }
709566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
719566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
748c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
768c3ff71bSJunchao Zhang }
778c3ff71bSJunchao Zhang 
78d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
79d71ae5a4SJacob Faibussowitsch {
808c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
818c3ff71bSJunchao Zhang   PetscInt    nt;
828c3ff71bSJunchao Zhang 
838c3ff71bSJunchao Zhang   PetscFunctionBegin;
849566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8508401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
869566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
879566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
889566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
899566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
918c3ff71bSJunchao Zhang }
928c3ff71bSJunchao Zhang 
93d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
94d71ae5a4SJacob Faibussowitsch {
958c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
968c3ff71bSJunchao Zhang   PetscInt    nt;
978c3ff71bSJunchao Zhang 
988c3ff71bSJunchao Zhang   PetscFunctionBegin;
999566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
10008401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
1019566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1029566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
1039566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1049566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
1053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1068c3ff71bSJunchao Zhang }
1078c3ff71bSJunchao Zhang 
108d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
109d71ae5a4SJacob Faibussowitsch {
1108c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1118c3ff71bSJunchao Zhang   PetscInt    nt;
1128c3ff71bSJunchao Zhang 
1138c3ff71bSJunchao Zhang   PetscFunctionBegin;
1149566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
11508401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1169566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1179566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1189566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1199566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1218c3ff71bSJunchao Zhang }
1228c3ff71bSJunchao Zhang 
123076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
124076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
125076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
126076ba34aSJunchao Zhang */
127d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
128d71ae5a4SJacob Faibussowitsch {
129076ba34aSJunchao Zhang   Mat             Ad, Ao;
130076ba34aSJunchao Zhang   const PetscInt *cmap;
131076ba34aSJunchao Zhang 
132076ba34aSJunchao Zhang   PetscFunctionBegin;
1339566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1349566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
135076ba34aSJunchao Zhang   if (glob) {
136076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1379566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1389566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1399566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1409566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
141076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
142076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1439566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
144076ba34aSJunchao Zhang   }
1453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
146076ba34aSJunchao Zhang }
147076ba34aSJunchao Zhang 
148076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
149076ba34aSJunchao Zhang struct MatMatStruct {
150076ba34aSJunchao Zhang   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
151076ba34aSJunchao Zhang   PetscSF             sf;      /* SF to send/recv matrix entries */
152076ba34aSJunchao Zhang   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
153076ba34aSJunchao Zhang   Mat                 C1, C2, B_local;
154076ba34aSJunchao Zhang   KokkosCsrMatrix     C1_global, C2_global, C_global;
155076ba34aSJunchao Zhang   KernelHandle        kh;
1563ba16761SJacob Faibussowitsch   MatMatStruct() noexcept : sf(nullptr), C1(nullptr), C2(nullptr), B_local(nullptr) { }
157076ba34aSJunchao Zhang 
158d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
159d71ae5a4SJacob Faibussowitsch   {
1603ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1613ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C1));
1623ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C2));
1633ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_local));
1643ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
165076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
1663ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
167076ba34aSJunchao Zhang   }
168076ba34aSJunchao Zhang };
169076ba34aSJunchao Zhang 
170076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
1713ba16761SJacob Faibussowitsch   MatColIdxKokkosView rows{};
1723ba16761SJacob Faibussowitsch   MatRowMapKokkosView rowoffset{};
1733ba16761SJacob Faibussowitsch   Mat                 B_other{}, C_petsc{}; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
174f0e6e2d1SJunchao Zhang   MatColIdxKokkosView B_NzDiagLeft;         // Number of nonzeros on the left of B's diagonal block; Used to recover the unsplit B (i.e., local mat)
175076ba34aSJunchao Zhang 
1763ba16761SJacob Faibussowitsch   ~MatMatStruct_AB() noexcept
177d71ae5a4SJacob Faibussowitsch   {
1783ba16761SJacob Faibussowitsch     PetscFunctionBegin;
1793ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_other));
1803ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C_petsc));
1813ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
182076ba34aSJunchao Zhang   }
183076ba34aSJunchao Zhang };
184076ba34aSJunchao Zhang 
185076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
186076ba34aSJunchao Zhang   MatRowMapKokkosView srcrowoffset, dstrowoffset;
187076ba34aSJunchao Zhang };
188076ba34aSJunchao Zhang 
1899371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
1903ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
1913ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
1923ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
193076ba34aSJunchao Zhang 
194d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
195d71ae5a4SJacob Faibussowitsch   {
196076ba34aSJunchao Zhang     delete mmAB;
197076ba34aSJunchao Zhang     delete mmAtB;
198076ba34aSJunchao Zhang   }
199076ba34aSJunchao Zhang };
200076ba34aSJunchao Zhang 
201d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202d71ae5a4SJacob Faibussowitsch {
203076ba34aSJunchao Zhang   PetscFunctionBegin;
2049566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
2053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
206076ba34aSJunchao Zhang }
207076ba34aSJunchao Zhang 
208076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
209076ba34aSJunchao Zhang 
210076ba34aSJunchao Zhang    Input Parameters:
211076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
212076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
213076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
214076ba34aSJunchao Zhang 
215076ba34aSJunchao Zhang    Output Parameters:
216076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
217076ba34aSJunchao Zhang  */
218d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
219d71ae5a4SJacob Faibussowitsch {
220076ba34aSJunchao Zhang   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
221076ba34aSJunchao Zhang   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
222076ba34aSJunchao Zhang 
223076ba34aSJunchao Zhang   PetscFunctionBegin;
2249371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
2259371c9d4SSatish Balay     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
2269566063dSJacob Faibussowitsch   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
2273ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
228076ba34aSJunchao Zhang }
229076ba34aSJunchao Zhang 
230076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
231076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
232076ba34aSJunchao Zhang 
233076ba34aSJunchao Zhang   Input Parameters:
234076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
235076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
236076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
237076ba34aSJunchao Zhang 
238076ba34aSJunchao Zhang   Output Parameters:
239076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
240076ba34aSJunchao Zhang */
241d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
242d71ae5a4SJacob Faibussowitsch {
243076ba34aSJunchao Zhang   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
244076ba34aSJunchao Zhang   PetscInt          m, n, M, N, Am, An, Bm, Bn;
245076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
246076ba34aSJunchao Zhang 
247076ba34aSJunchao Zhang   PetscFunctionBegin;
2489566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2499566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2509566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2519566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
252076ba34aSJunchao Zhang 
253aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
25408401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
25508401ef6SPierre Jolivet   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
25608401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
257076ba34aSJunchao Zhang   mpiaij->A = A;
258076ba34aSJunchao Zhang   mpiaij->B = B;
259076ba34aSJunchao Zhang 
260076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
261076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
262076ba34aSJunchao Zhang 
2639566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2649566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
265076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
266076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
267076ba34aSJunchao Zhang   */
2689566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2699566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2709566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
271076ba34aSJunchao Zhang 
272076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
273076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
274076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
275076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
2763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
277076ba34aSJunchao Zhang }
278076ba34aSJunchao Zhang 
/* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrix (B) to form a SEQAIJKOKKOS matrix (C).
280076ba34aSJunchao Zhang 
281076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
282076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
283076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
284076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
285076ba34aSJunchao Zhang 
286*2ef1f0ffSBarry Smith    Collective
287076ba34aSJunchao Zhang 
288076ba34aSJunchao Zhang    Input Parameters:
289076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
290076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
291076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
292076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
293076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
294076ba34aSJunchao Zhang 
   Input/Output Parameters (out when reuse = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
296076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
297076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
298076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
299076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
300076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
301076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
302076ba34aSJunchao Zhang */
303d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
304d71ae5a4SJacob Faibussowitsch {
305076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok, *ckok;
306076ba34aSJunchao Zhang 
307076ba34aSJunchao Zhang   PetscFunctionBegin;
3089566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
309076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
310076ba34aSJunchao Zhang 
311076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
312076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
313076ba34aSJunchao Zhang 
314076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
315076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
316076ba34aSJunchao Zhang     const auto &Ca = ckok->a_dual.view_device();
317076ba34aSJunchao Zhang 
318076ba34aSJunchao Zhang     /* Copy Ba to abuf */
3199371c9d4SSatish Balay     Kokkos::parallel_for(
3209371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
321076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
322076ba34aSJunchao Zhang         PetscInt r    = rows(i);
323076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
3249371c9d4SSatish Balay         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
325076ba34aSJunchao Zhang       });
326076ba34aSJunchao Zhang 
327076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
3289566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
3299566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
330076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
331076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
332076ba34aSJunchao Zhang     MPI_Comm    comm;
333076ba34aSJunchao Zhang     PetscMPIInt tag;
334076ba34aSJunchao Zhang     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
335076ba34aSJunchao Zhang 
3369566063dSJacob Faibussowitsch     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
3379566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
338076ba34aSJunchao Zhang     Cm = nleaves; /* row size of C */
339076ba34aSJunchao Zhang     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
340076ba34aSJunchao Zhang 
341076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
342076ba34aSJunchao Zhang     PetscInt       *Browlens;
343076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
3449566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nroots, &Browlens));
345076ba34aSJunchao Zhang     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
346076ba34aSJunchao Zhang 
347076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
348076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
349076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
350076ba34aSJunchao Zhang     Ci_h[0] = 0;
3519566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
3529566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
353076ba34aSJunchao Zhang     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
354076ba34aSJunchao Zhang     Cnnz = Ci_h[Cm];
355076ba34aSJunchao Zhang     Ci.modify_host();
356076ba34aSJunchao Zhang     Ci.sync_device();
357076ba34aSJunchao Zhang 
358076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
359076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cj("j", Cnnz);
360076ba34aSJunchao Zhang     MatScalarKokkosDualView Ca("a", Cnnz);
361076ba34aSJunchao Zhang 
362076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
363076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
364076ba34aSJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
365076ba34aSJunchao Zhang     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
366076ba34aSJunchao Zhang     MPI_Request       *reqs;
367076ba34aSJunchao Zhang 
3689566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
3699566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
370076ba34aSJunchao Zhang 
371076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
372076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
373076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
374076ba34aSJunchao Zhang 
375076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
376076ba34aSJunchao Zhang     */
3779566063dSJacob Faibussowitsch     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
378076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
379076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
380076ba34aSJunchao Zhang 
381076ba34aSJunchao Zhang     sdisp[0]  = 0;
382076ba34aSJunchao Zhang     rowptr[0] = 0;
383076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) { /* for each receiver */
384076ba34aSJunchao Zhang       PetscInt len, nz = 0;
385076ba34aSJunchao Zhang       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
386076ba34aSJunchao Zhang         len           = Browlens[irootloc[j]];
387076ba34aSJunchao Zhang         rowptr[j + 1] = rowptr[j] + len;
388076ba34aSJunchao Zhang         nz += len;
389076ba34aSJunchao Zhang       }
390076ba34aSJunchao Zhang       sdisp[i + 1] = sdisp[i] + nz;
391076ba34aSJunchao Zhang     }
3929566063dSJacob Faibussowitsch     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
3939566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
3949566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
3959566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
396076ba34aSJunchao Zhang 
397076ba34aSJunchao Zhang     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
398076ba34aSJunchao Zhang     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
399076ba34aSJunchao Zhang     PetscSFNode *iremote;
4009566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote));
401076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) { /* for each sender */
402076ba34aSJunchao Zhang       k = 0;
403076ba34aSJunchao Zhang       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
404076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
405076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
406076ba34aSJunchao Zhang         k++;
407076ba34aSJunchao Zhang       }
408076ba34aSJunchao Zhang     }
409076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
4109566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &bcastSF));
4119566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
412076ba34aSJunchao Zhang 
413076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
414076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
415076ba34aSJunchao Zhang     */
416076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
417076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
418076ba34aSJunchao Zhang     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
419076ba34aSJunchao Zhang 
420076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
421076ba34aSJunchao Zhang 
422076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
423076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
424076ba34aSJunchao Zhang 
425076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
426076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
427076ba34aSJunchao Zhang     const auto &Bj = bkok->j_dual.view_device();
428076ba34aSJunchao Zhang 
429076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
4309371c9d4SSatish Balay     Kokkos::parallel_for(
4319371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
432076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
433076ba34aSJunchao Zhang         PetscInt r    = rows(i);
434076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
435076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
436076ba34aSJunchao Zhang           abuf(base + k) = Ba(Bi(r) + k);
437076ba34aSJunchao Zhang           jbuf(base + k) = l2g(Bj(Bi(r) + k));
438076ba34aSJunchao Zhang         });
439076ba34aSJunchao Zhang       });
440076ba34aSJunchao Zhang 
441076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
4429566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4439566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
4449566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4459566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
446076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
447076ba34aSJunchao Zhang     Cj.sync_host();
448076ba34aSJunchao Zhang     Ca.modify_device();
449076ba34aSJunchao Zhang 
450076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
451076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
4529566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
4539566063dSJacob Faibussowitsch     PetscCall(PetscFree3(sdisp, rdisp, reqs));
4549566063dSJacob Faibussowitsch     PetscCall(PetscFree(Browlens));
45598921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
4563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
457076ba34aSJunchao Zhang }
458076ba34aSJunchao Zhang 
459076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
460076ba34aSJunchao Zhang 
461076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
462076ba34aSJunchao Zhang 
463076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
464076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
465076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
466076ba34aSJunchao Zhang 
467076ba34aSJunchao Zhang   Input Parameters:
468076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
469076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
470076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
471076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
472076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
473076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
474076ba34aSJunchao Zhang 
475076ba34aSJunchao Zhang   Output Parameters:
476076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
477076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
478076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
479076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
480076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
481076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
482076ba34aSJunchao Zhang 
483da81f932SPierre Jolivet    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
484076ba34aSJunchao Zhang  */
485d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
486d71ae5a4SJacob Faibussowitsch {
487076ba34aSJunchao Zhang   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
488076ba34aSJunchao Zhang   const PetscInt   *Ai;
489076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok;
490076ba34aSJunchao Zhang 
491076ba34aSJunchao Zhang   PetscFunctionBegin;
4929566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
4939566063dSJacob Faibussowitsch   PetscCall(MatGetSize(A, &Am, &An));
494076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
495076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
496076ba34aSJunchao Zhang   Annz = Ai[Am];
497076ba34aSJunchao Zhang 
498076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
499076ba34aSJunchao Zhang     /* Send Aa to abuf */
5009566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
5019566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
502076ba34aSJunchao Zhang 
503076ba34aSJunchao Zhang     /* Copy abuf to Ca */
504076ba34aSJunchao Zhang     const MatScalarKokkosView &Ca = C.values;
505076ba34aSJunchao Zhang     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
5069371c9d4SSatish Balay     Kokkos::parallel_for(
5079371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
508076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
509076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
510076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
511076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
512076ba34aSJunchao Zhang       });
513076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
514076ba34aSJunchao Zhang     MPI_Comm     comm;
515076ba34aSJunchao Zhang     MPI_Request *reqs;
516076ba34aSJunchao Zhang     PetscMPIInt  tag;
517076ba34aSJunchao Zhang     PetscInt     Cm;
518076ba34aSJunchao Zhang 
5199566063dSJacob Faibussowitsch     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
5209566063dSJacob Faibussowitsch     PetscCall(PetscCommGetNewTag(comm, &tag));
521076ba34aSJunchao Zhang 
522076ba34aSJunchao Zhang     PetscInt           niranks, nranks, nroots, nleaves;
523076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
524076ba34aSJunchao Zhang     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
5259566063dSJacob Faibussowitsch     PetscCall(PetscSFSetUp(ownerSF));
5269566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
5279566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
5289566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
52908401ef6SPierre Jolivet     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
530076ba34aSJunchao Zhang     Cm    = nroots;
531076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
532076ba34aSJunchao Zhang 
533076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
534076ba34aSJunchao Zhang     PetscInt               *srowlens;                              /* send buf of row lens */
535076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
536076ba34aSJunchao Zhang     PetscInt               *rrowlens = rrowlens_h.data();
537076ba34aSJunchao Zhang 
5389566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
539076ba34aSJunchao Zhang     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
540076ba34aSJunchao Zhang     rrowlens[0] = 0;
541076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
5429566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
5439566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
5449566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
545076ba34aSJunchao Zhang 
546076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
547076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
548076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
549076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
550076ba34aSJunchao Zhang 
551076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {
552076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
553076ba34aSJunchao Zhang #if defined(PETSC_USE_DEBUG)
554aed4548fSBarry Smith       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
555076ba34aSJunchao Zhang #endif
556076ba34aSJunchao Zhang       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
557076ba34aSJunchao Zhang     }
558076ba34aSJunchao Zhang     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
559076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
560076ba34aSJunchao Zhang 
561076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
562076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
563076ba34aSJunchao Zhang     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
564076ba34aSJunchao Zhang     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
565076ba34aSJunchao Zhang 
5669566063dSJacob Faibussowitsch     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
567076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
568076ba34aSJunchao Zhang       r                    = rows[i];                   /* row id in C */
569076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
570076ba34aSJunchao Zhang       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
571076ba34aSJunchao Zhang     }
5729566063dSJacob Faibussowitsch     PetscCall(PetscFree(currowlens));
573076ba34aSJunchao Zhang 
574076ba34aSJunchao Zhang     rrowlens--;
575076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
576076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
577076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
578076ba34aSJunchao Zhang 
579076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
580076ba34aSJunchao Zhang     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
5819566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
582076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
5839566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
5849566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
5859566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
586076ba34aSJunchao Zhang 
587076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
588076ba34aSJunchao Zhang     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
589076ba34aSJunchao Zhang     PetscSFNode *iremote;
5909566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
591076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) {
592076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
593076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
594076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
595076ba34aSJunchao Zhang       for (PetscInt k = 0; k < nz; k++) {
596076ba34aSJunchao Zhang         iremote[leafbase + k].rank  = ranks[i];
597076ba34aSJunchao Zhang         iremote[leafbase + k].index = rootbase + k;
598076ba34aSJunchao Zhang       }
599076ba34aSJunchao Zhang     }
6009566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &reduceSF));
6019566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
6029566063dSJacob Faibussowitsch     PetscCall(PetscFree2(sdisp, rdisp));
603076ba34aSJunchao Zhang 
604076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
605076ba34aSJunchao Zhang 
606076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
607076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
608076ba34aSJunchao Zhang     if (local) {
609076ba34aSJunchao Zhang       Ajg                           = MatColIdxKokkosView("j", Annz);
610076ba34aSJunchao Zhang       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
6119371c9d4SSatish Balay       Kokkos::parallel_for(
6129371c9d4SSatish Balay         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
613076ba34aSJunchao Zhang     } else {
614076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
615076ba34aSJunchao Zhang     }
616076ba34aSJunchao Zhang 
617076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", Cnnz);
618076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", Cnnz);
6199566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
6209566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
6219566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
6229566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
623076ba34aSJunchao Zhang 
624076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
625076ba34aSJunchao Zhang     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
626076ba34aSJunchao Zhang     MatColIdxKokkosView Cj("j", Cnnz);
627076ba34aSJunchao Zhang     MatScalarKokkosView Ca("a", Cnnz);
628076ba34aSJunchao Zhang 
6299371c9d4SSatish Balay     Kokkos::parallel_for(
6309371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
631076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
632076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
633076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
634076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
635076ba34aSJunchao Zhang           Ca(dst + k) = abuf(src + k);
636076ba34aSJunchao Zhang           Cj(dst + k) = jbuf(src + k);
637076ba34aSJunchao Zhang         });
638076ba34aSJunchao Zhang       });
639076ba34aSJunchao Zhang 
640076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
641076ba34aSJunchao Zhang     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
6429566063dSJacob Faibussowitsch     PetscCall(PetscFree2(srowlens, reqs));
64398921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
6443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
645076ba34aSJunchao Zhang }
646076ba34aSJunchao Zhang 
64711a5261eSBarry Smith /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
648076ba34aSJunchao Zhang 
649076ba34aSJunchao Zhang   Input Parameters:
65011a5261eSBarry Smith +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
651076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
652076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
65311a5261eSBarry Smith -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
654076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
655076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
656076ba34aSJunchao Zhang 
657076ba34aSJunchao Zhang   Output Parameters:
65811a5261eSBarry Smith +  C        - the updated `MATMPIAIJKOKKOS` matrix
65911a5261eSBarry Smith -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
660076ba34aSJunchao Zhang 
66111a5261eSBarry Smith   Note:
66211a5261eSBarry Smith    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
66311a5261eSBarry Smith 
664*2ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MATMPIAIJKOKKOS`
665076ba34aSJunchao Zhang  */
666d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
667d71ae5a4SJacob Faibussowitsch {
668076ba34aSJunchao Zhang   const MatScalarKokkosView      &Ca = csrmat.values;
669076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
670076ba34aSJunchao Zhang   PetscInt                        m, n, N;
671076ba34aSJunchao Zhang 
672076ba34aSJunchao Zhang   PetscFunctionBegin;
6739566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(C, &m, &n));
6749566063dSJacob Faibussowitsch   PetscCall(MatGetSize(C, NULL, &N));
675076ba34aSJunchao Zhang 
676076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
677076ba34aSJunchao Zhang     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
678076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
679076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
680076ba34aSJunchao Zhang     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
681076ba34aSJunchao Zhang     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
682076ba34aSJunchao Zhang 
683076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
6849371c9d4SSatish Balay     Kokkos::parallel_for(
6859371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
686076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
687076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
688076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
689076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
690076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
691076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
692076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
693076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
694076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
695076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
696076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
697076ba34aSJunchao Zhang           } else { /* k in [cdend, clen) */
698076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
699076ba34aSJunchao Zhang           }
700076ba34aSJunchao Zhang         });
701076ba34aSJunchao Zhang       });
702076ba34aSJunchao Zhang 
703076ba34aSJunchao Zhang     akok->a_dual.modify_device();
704076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
705076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
706076ba34aSJunchao Zhang     Mat                        Cd, Co;
707076ba34aSJunchao Zhang     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
708076ba34aSJunchao Zhang     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
709076ba34aSJunchao Zhang     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
710076ba34aSJunchao Zhang     PetscInt                   cstart, cend;
711076ba34aSJunchao Zhang 
712076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
713076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
714076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
715076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
716076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
717076ba34aSJunchao Zhang      */
718076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart", m);
7199566063dSJacob Faibussowitsch     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
720076ba34aSJunchao Zhang 
721076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
722076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
723076ba34aSJunchao Zhang      */
7249371c9d4SSatish Balay     Kokkos::parallel_for(
7259371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
726076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
727076ba34aSJunchao Zhang                                                    PetscInt i = t.league_rank(); /* row i */
728076ba34aSJunchao Zhang                                                    PetscInt j, first, count, step;
729076ba34aSJunchao Zhang 
730076ba34aSJunchao Zhang                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
731076ba34aSJunchao Zhang                                                      Cdi(0) = 0;
732076ba34aSJunchao Zhang                                                      Coi(0) = 0;
733076ba34aSJunchao Zhang                                                    }
734076ba34aSJunchao Zhang 
735076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
736076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
737076ba34aSJunchao Zhang         */
738076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - Ci(i);
739076ba34aSJunchao Zhang                                                    first = Ci(i);
740076ba34aSJunchao Zhang                                                    while (count > 0) {
741076ba34aSJunchao Zhang                                                      j    = first;
742076ba34aSJunchao Zhang                                                      step = count / 2;
743076ba34aSJunchao Zhang                                                      j += step;
744076ba34aSJunchao Zhang                                                      if (Cj(j) < cstart) {
745076ba34aSJunchao Zhang                                                        first = ++j;
746076ba34aSJunchao Zhang                                                        count -= step + 1;
747076ba34aSJunchao Zhang                                                      } else count = step;
748076ba34aSJunchao Zhang                                                    }
749076ba34aSJunchao Zhang                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
750076ba34aSJunchao Zhang 
751076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
752076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - first;
753076ba34aSJunchao Zhang                                                    while (count > 0) {
754076ba34aSJunchao Zhang                                                      j    = first;
755076ba34aSJunchao Zhang                                                      step = count / 2;
756076ba34aSJunchao Zhang                                                      j += step;
757076ba34aSJunchao Zhang                                                      if (Cj(j) < cend) {
758076ba34aSJunchao Zhang                                                        first = ++j;
759076ba34aSJunchao Zhang                                                        count -= step + 1;
760076ba34aSJunchao Zhang                                                      } else count = step;
761076ba34aSJunchao Zhang                                                    }
762076ba34aSJunchao Zhang                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
763076ba34aSJunchao Zhang                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
764076ba34aSJunchao Zhang         });
765076ba34aSJunchao Zhang       });
766076ba34aSJunchao Zhang 
767076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
7689371c9d4SSatish Balay     Kokkos::parallel_scan(
7699371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
770076ba34aSJunchao Zhang         update += Cdi(i);
771076ba34aSJunchao Zhang         if (final) Cdi(i) = update;
772076ba34aSJunchao Zhang       });
7739371c9d4SSatish Balay     Kokkos::parallel_scan(
7749371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
775076ba34aSJunchao Zhang         update += Coi(i);
776076ba34aSJunchao Zhang         if (final) Coi(i) = update;
777076ba34aSJunchao Zhang       });
778076ba34aSJunchao Zhang 
779076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
780076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
781076ba34aSJunchao Zhang     */
782076ba34aSJunchao Zhang     Cdi_dual.modify_device();
783076ba34aSJunchao Zhang     Coi_dual.modify_device();
784076ba34aSJunchao Zhang     Cdi_dual.sync_host();
785076ba34aSJunchao Zhang     Coi_dual.sync_host();
786076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
787076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
788076ba34aSJunchao Zhang 
789076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
790076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
791076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
792076ba34aSJunchao Zhang 
793076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
794076ba34aSJunchao Zhang     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
795076ba34aSJunchao Zhang     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
796076ba34aSJunchao Zhang 
7979371c9d4SSatish Balay     Kokkos::parallel_for(
7989371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
799076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
800076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
801076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
802076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
803076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
804076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
805076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
806076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
807076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
808076ba34aSJunchao Zhang             Coj(Coi(i) + k) = Cj(Ci(i) + k);
809076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
810076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
811076ba34aSJunchao Zhang             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
812076ba34aSJunchao Zhang           } else {                                                /* k in [cdend, clen) */
813076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
814076ba34aSJunchao Zhang             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
815076ba34aSJunchao Zhang           }
816076ba34aSJunchao Zhang         });
817076ba34aSJunchao Zhang       });
818076ba34aSJunchao Zhang 
819076ba34aSJunchao Zhang     Cdj_dual.modify_device();
820076ba34aSJunchao Zhang     Cda_dual.modify_device();
821076ba34aSJunchao Zhang     Coj_dual.modify_device();
822076ba34aSJunchao Zhang     Coa_dual.modify_device();
823076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
824076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
825076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
8269566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
8279566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
8289566063dSJacob Faibussowitsch     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
829076ba34aSJunchao Zhang   }
8303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
831076ba34aSJunchao Zhang }
832076ba34aSJunchao Zhang 
833076ba34aSJunchao Zhang /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
834076ba34aSJunchao Zhang 
835076ba34aSJunchao Zhang   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
836076ba34aSJunchao Zhang 
837076ba34aSJunchao Zhang   Input Parameters:
838076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
839076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
840076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
841076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
842076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array.
843076ba34aSJunchao Zhang 
844076ba34aSJunchao Zhang   Output Parameters:
845076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
846076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
847076ba34aSJunchao Zhang 
84811a5261eSBarry Smith   Note:
84911a5261eSBarry Smith   the input matrix's col ids and col size will be changed.
850076ba34aSJunchao Zhang */
851d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
852d71ae5a4SJacob Faibussowitsch {
853076ba34aSJunchao Zhang   Mat_SeqAIJKokkos      *ckok;
854076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
855076ba34aSJunchao Zhang   const PetscInt        *garray;
856076ba34aSJunchao Zhang   PetscInt               sz;
857076ba34aSJunchao Zhang 
858076ba34aSJunchao Zhang   PetscFunctionBegin;
859076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
8609566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
861076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
862076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
863076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
864076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
865076ba34aSJunchao Zhang 
866076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
8679566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
8689566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
86908401ef6SPierre Jolivet   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
870076ba34aSJunchao Zhang 
871076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray, sz);
872076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g", sz);
873076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g, tmp);
874076ba34aSJunchao Zhang 
8759566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
8769566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
8773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
878076ba34aSJunchao Zhang }
879076ba34aSJunchao Zhang 
880f0e6e2d1SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 99)
881f0e6e2d1SJunchao Zhang static PetscErrorCode MatMPIAIJGetLocalMat_MPIAIJKokkos(Mat mat, MatReuse reuse, MatMatStruct_AB *mm, Mat *C)
882f0e6e2d1SJunchao Zhang {
883f0e6e2d1SJunchao Zhang   Mat                 A, B;
884f0e6e2d1SJunchao Zhang   const PetscInt     *garray;
885f0e6e2d1SJunchao Zhang   Mat_SeqAIJ         *aseq, *bseq;
886f0e6e2d1SJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
887f0e6e2d1SJunchao Zhang   MatScalarKokkosView aa, ba, ca;
888f0e6e2d1SJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
889f0e6e2d1SJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
890f0e6e2d1SJunchao Zhang   PetscInt            m, nnz;
891f0e6e2d1SJunchao Zhang 
892f0e6e2d1SJunchao Zhang   PetscFunctionBegin;
893f0e6e2d1SJunchao Zhang   PetscCall(MatMPIAIJGetSeqAIJ(mat, &A, &B, &garray));
894f0e6e2d1SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
895f0e6e2d1SJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
896f0e6e2d1SJunchao Zhang   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
897f0e6e2d1SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
898f0e6e2d1SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(B));
899f0e6e2d1SJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(A->data);
900f0e6e2d1SJunchao Zhang   bseq = static_cast<Mat_SeqAIJ *>(B->data);
901f0e6e2d1SJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
902f0e6e2d1SJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
903f0e6e2d1SJunchao Zhang   aa   = akok->a_dual.view_device();
904f0e6e2d1SJunchao Zhang   ai   = akok->i_dual.view_device();
905f0e6e2d1SJunchao Zhang   ba   = bkok->a_dual.view_device();
906f0e6e2d1SJunchao Zhang   bi   = bkok->i_dual.view_device();
907f0e6e2d1SJunchao Zhang   m    = A->rmap->n; /* M and nnz of C */
908f0e6e2d1SJunchao Zhang   nnz  = aseq->nz + bseq->nz;
909f0e6e2d1SJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
910f0e6e2d1SJunchao Zhang     aj           = akok->j_dual.view_device();
911f0e6e2d1SJunchao Zhang     bj           = bkok->j_dual.view_device();
912f0e6e2d1SJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", nnz);
913f0e6e2d1SJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", m + 1);
914f0e6e2d1SJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", nnz);
915f0e6e2d1SJunchao Zhang     ca           = ca_dual.view_device();
916f0e6e2d1SJunchao Zhang     ci           = ci_dual.view_device();
917f0e6e2d1SJunchao Zhang     cj           = cj_dual.view_device();
918f0e6e2d1SJunchao Zhang 
919f0e6e2d1SJunchao Zhang     // For each row of B, find number of nonzeros on the left of the diagonal block (i.e., A).
920f0e6e2d1SJunchao Zhang     // The result is stored in mm->B_NzDiagLeft for reuse in the numeric phase
921f0e6e2d1SJunchao Zhang     MatColIdxKokkosViewHost NzLeft("NzLeft", m);
922f0e6e2d1SJunchao Zhang     const MatRowMapType    *rowptr = bkok->i_host_data();
923f0e6e2d1SJunchao Zhang     const MatColIdxType    *colidx = bkok->j_host_data();
924f0e6e2d1SJunchao Zhang     MatColIdxType          *nzleft = NzLeft.data();
925f0e6e2d1SJunchao Zhang     const MatColIdxType     cstart = mat->cmap->rstart; // start of global column indices of A; used to split B
926f0e6e2d1SJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
927f0e6e2d1SJunchao Zhang       const MatColIdxType *first, *last, *it;
928f0e6e2d1SJunchao Zhang       PetscInt             count, step;
929f0e6e2d1SJunchao Zhang 
930f0e6e2d1SJunchao Zhang       // Basically, std::lower_bound(first,last,cstart), but need to map columns from local to global with garray[]
931f0e6e2d1SJunchao Zhang       first = colidx + rowptr[i];
932f0e6e2d1SJunchao Zhang       last  = colidx + rowptr[i + 1];
933f0e6e2d1SJunchao Zhang       count = last - first;
934f0e6e2d1SJunchao Zhang       while (count > 0) {
935f0e6e2d1SJunchao Zhang         it   = first;
936f0e6e2d1SJunchao Zhang         step = count / 2;
937f0e6e2d1SJunchao Zhang         it += step;
938f0e6e2d1SJunchao Zhang         if (garray[*it] < cstart) {
939f0e6e2d1SJunchao Zhang           first = ++it;
940f0e6e2d1SJunchao Zhang           count -= step + 1;
941f0e6e2d1SJunchao Zhang         } else count = step;
942f0e6e2d1SJunchao Zhang       }
943f0e6e2d1SJunchao Zhang       nzleft[i] = first - (colidx + rowptr[i]);
944f0e6e2d1SJunchao Zhang     }
945f0e6e2d1SJunchao Zhang     auto B_NzDiagLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), NzLeft); // copy to device
946f0e6e2d1SJunchao Zhang 
947f0e6e2d1SJunchao Zhang     auto tmp = MatColIdxKokkosViewHost(const_cast<MatColIdxType *>(garray), B->cmap->n);
948f0e6e2d1SJunchao Zhang     auto l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); // copy garray to device
949f0e6e2d1SJunchao Zhang 
950f0e6e2d1SJunchao Zhang     // Shuffle A and B in parallel using Kokkos hierarchical parallelism
951f0e6e2d1SJunchao Zhang     Kokkos::parallel_for(
952f0e6e2d1SJunchao Zhang       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
953f0e6e2d1SJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
954f0e6e2d1SJunchao Zhang         PetscInt disp = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
955f0e6e2d1SJunchao Zhang         PetscInt nzleft = B_NzDiagLeft(i);
956f0e6e2d1SJunchao Zhang 
957f0e6e2d1SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
958f0e6e2d1SJunchao Zhang           ci(i) = disp;
959f0e6e2d1SJunchao Zhang           if (i == m - 1) ci(m) = ai(m) + bi(m);
960f0e6e2d1SJunchao Zhang         });
961f0e6e2d1SJunchao Zhang 
962f0e6e2d1SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
963f0e6e2d1SJunchao Zhang           if (k < nzleft) { // portion of B that is on left of A
964f0e6e2d1SJunchao Zhang             ca(disp + k) = ba(bi(i) + k);
965f0e6e2d1SJunchao Zhang             cj(disp + k) = l2g(bj(bi(i) + k));
966f0e6e2d1SJunchao Zhang           } else if (k < nzleft + alen) { // diag A
967f0e6e2d1SJunchao Zhang             ca(disp + k) = aa(ai(i) + k - nzleft);
968f0e6e2d1SJunchao Zhang             cj(disp + k) = aj(ai(i) + k - nzleft) + cstart; // add the shift to convert local to global.
969f0e6e2d1SJunchao Zhang           } else {                                          // portion of B that is on right of A
970f0e6e2d1SJunchao Zhang             ca(disp + k) = ba(bi(i) + k - alen);
971f0e6e2d1SJunchao Zhang             cj(disp + k) = l2g(bj(bi(i) + k - alen));
972f0e6e2d1SJunchao Zhang           }
973f0e6e2d1SJunchao Zhang         });
974f0e6e2d1SJunchao Zhang       });
975f0e6e2d1SJunchao Zhang     ca_dual.modify_device();
976f0e6e2d1SJunchao Zhang     ci_dual.modify_device();
977f0e6e2d1SJunchao Zhang     cj_dual.modify_device();
978f0e6e2d1SJunchao Zhang     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, mat->cmap->N, nnz, ci_dual, cj_dual, ca_dual));
979f0e6e2d1SJunchao Zhang     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
980f0e6e2d1SJunchao Zhang     mm->B_NzDiagLeft = B_NzDiagLeft;
981f0e6e2d1SJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
982f0e6e2d1SJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
983f0e6e2d1SJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
984f0e6e2d1SJunchao Zhang     ckok               = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
985f0e6e2d1SJunchao Zhang     ca                 = ckok->a_dual.view_device();
986f0e6e2d1SJunchao Zhang     auto &B_NzDiagLeft = mm->B_NzDiagLeft;
987f0e6e2d1SJunchao Zhang 
988f0e6e2d1SJunchao Zhang     Kokkos::parallel_for(
989f0e6e2d1SJunchao Zhang       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
990f0e6e2d1SJunchao Zhang         PetscInt i    = t.league_rank(); // row i
991f0e6e2d1SJunchao Zhang         PetscInt disp = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
992f0e6e2d1SJunchao Zhang         PetscInt nzleft = B_NzDiagLeft(i);
993f0e6e2d1SJunchao Zhang 
994f0e6e2d1SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
995f0e6e2d1SJunchao Zhang           if (k < nzleft) { // portion of B that is on left of A
996f0e6e2d1SJunchao Zhang             ca(disp + k) = ba(bi(i) + k);
997f0e6e2d1SJunchao Zhang           } else if (k < nzleft + alen) { // diag A
998f0e6e2d1SJunchao Zhang             ca(disp + k) = aa(ai(i) + k - nzleft);
999f0e6e2d1SJunchao Zhang           } else { // portion of B that is on right of A
1000f0e6e2d1SJunchao Zhang             ca(disp + k) = ba(bi(i) + k - alen);
1001f0e6e2d1SJunchao Zhang           }
1002f0e6e2d1SJunchao Zhang         });
1003f0e6e2d1SJunchao Zhang       });
1004f0e6e2d1SJunchao Zhang 
1005f0e6e2d1SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
1006f0e6e2d1SJunchao Zhang   }
1007f0e6e2d1SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1008f0e6e2d1SJunchao Zhang }
1009f0e6e2d1SJunchao Zhang #endif
1010f0e6e2d1SJunchao Zhang 
1011076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1012076ba34aSJunchao Zhang 
1013076ba34aSJunchao Zhang   Input Parameters:
1014076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1015076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1016076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
1017076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1018076ba34aSJunchao Zhang 
101911a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
1020076ba34aSJunchao Zhang */
1021d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1022d71ae5a4SJacob Faibussowitsch {
1023076ba34aSJunchao Zhang   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
1024076ba34aSJunchao Zhang   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
1025076ba34aSJunchao Zhang   IS                       glob = NULL;
1026076ba34aSJunchao Zhang   const PetscInt          *garray;
1027076ba34aSJunchao Zhang   PetscInt                 N = B->cmap->N, sz;
1028076ba34aSJunchao Zhang   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
1029076ba34aSJunchao Zhang   MatColIdxKokkosView      l2g2;
1030076ba34aSJunchao Zhang   Mat                      C1, C2; /* intermediate matrices */
1031076ba34aSJunchao Zhang 
1032076ba34aSJunchao Zhang   PetscFunctionBegin;
1033f0e6e2d1SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1034076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
10359566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
1036f0e6e2d1SJunchao Zhang #else
1037f0e6e2d1SJunchao Zhang   PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_INITIAL_MATRIX, mm, &mm->B_local));
1038f0e6e2d1SJunchao Zhang   PetscCall(ISCreateStride(MPI_COMM_SELF, N, 0, 1, &glob));
1039f0e6e2d1SJunchao Zhang #endif
1040f0e6e2d1SJunchao Zhang 
10419566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
10429566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
10439566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
1044076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
10459566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
1046dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
1047076ba34aSJunchao Zhang 
10489566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(glob, &garray));
10499566063dSJacob Faibussowitsch   PetscCall(ISGetSize(glob, &sz));
1050076ba34aSJunchao Zhang   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
1051076ba34aSJunchao Zhang   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
10529566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
1053076ba34aSJunchao Zhang 
1054076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
10559566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
1056076ba34aSJunchao Zhang 
1057da81f932SPierre Jolivet   /* Compact B_other to use local ids as we guess KK spgemm is more memory scalable with that; We could skip the compaction to simplify code */
10589566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
10599566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
10609566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
10619566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
1062076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
10639566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
1064dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
10659566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
1066076ba34aSJunchao Zhang 
1067076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
1068076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
1069076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
1070076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
1071076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
1072076ba34aSJunchao Zhang 
1073076ba34aSJunchao Zhang   mm->C1 = C1;
1074076ba34aSJunchao Zhang   mm->C2 = C2;
10759566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(glob, &garray));
10769566063dSJacob Faibussowitsch   PetscCall(ISDestroy(&glob));
10773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1078076ba34aSJunchao Zhang }
1079076ba34aSJunchao Zhang 
1080076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
1081076ba34aSJunchao Zhang 
1082076ba34aSJunchao Zhang   Input Parameters:
1083076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1084076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
1085076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
1086076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
1087076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
1088076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
1089076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
1090076ba34aSJunchao Zhang 
109111a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
1092076ba34aSJunchao Zhang */
1093d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
1094d71ae5a4SJacob Faibussowitsch {
1095076ba34aSJunchao Zhang   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
1096076ba34aSJunchao Zhang   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
1097076ba34aSJunchao Zhang   Mat         C1, C2;               /* intermediate matrices */
1098076ba34aSJunchao Zhang 
1099076ba34aSJunchao Zhang   PetscFunctionBegin;
1100076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
11019566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
11029566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
11039566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
1104076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
11059566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
1106dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
1107076ba34aSJunchao Zhang 
11089566063dSJacob Faibussowitsch   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
1109076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
1110076ba34aSJunchao Zhang 
1111076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
11129566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
11139566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
11149566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
1115076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
11169566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
1117dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
1118076ba34aSJunchao Zhang 
11199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
1120076ba34aSJunchao Zhang 
1121076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
1122076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
1123076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
1124076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
1125076ba34aSJunchao Zhang   mm->C1 = C1;
1126076ba34aSJunchao Zhang   mm->C2 = C2;
11273ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1128076ba34aSJunchao Zhang }
1129076ba34aSJunchao Zhang 
1130d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1131d71ae5a4SJacob Faibussowitsch {
1132076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1133076ba34aSJunchao Zhang   MatProductType               ptype;
1134076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1135076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
1136076ba34aSJunchao Zhang   MatMatStruct_AB             *ab;
1137076ba34aSJunchao Zhang   MatMatStruct_AtB            *atb;
1138076ba34aSJunchao Zhang   Mat                          A, B, Ad, Ao, Bd, Bo;
1139076ba34aSJunchao Zhang   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1140076ba34aSJunchao Zhang 
1141076ba34aSJunchao Zhang   PetscFunctionBegin;
1142076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
1143076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1144076ba34aSJunchao Zhang   ptype  = product->type;
1145076ba34aSJunchao Zhang   A      = product->A;
1146076ba34aSJunchao Zhang   B      = product->B;
1147076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1148076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1149076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1150076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1151076ba34aSJunchao Zhang 
1152076ba34aSJunchao Zhang   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1153076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1154076ba34aSJunchao Zhang     ab               = mmdata->mmAB;
1155076ba34aSJunchao Zhang     atb              = mmdata->mmAtB;
1156076ba34aSJunchao Zhang     if (ab) {
1157076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1158076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1159076ba34aSJunchao Zhang     }
1160076ba34aSJunchao Zhang     if (atb) {
1161076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1162076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1163076ba34aSJunchao Zhang     }
11643ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1165076ba34aSJunchao Zhang   }
1166076ba34aSJunchao Zhang 
1167076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1168076ba34aSJunchao Zhang     ab = mmdata->mmAB;
1169076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
117008401ef6SPierre Jolivet     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1171f0e6e2d1SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
11729566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1173f0e6e2d1SJunchao Zhang #else
1174f0e6e2d1SJunchao Zhang     PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_REUSE_MATRIX, ab, &ab->B_local));
1175f0e6e2d1SJunchao Zhang #endif
1176f0e6e2d1SJunchao Zhang 
11775f80ce2aSJacob Faibussowitsch     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
11789566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
11799566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
11809371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1181076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
118208401ef6SPierre Jolivet     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
11839566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
11849566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1185076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1186076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1187076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(ab);
1188076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1189076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
119008401ef6SPierre Jolivet     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1191076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
11929566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
119308401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
11949566063dSJacob Faibussowitsch     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
11959566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1196076ba34aSJunchao Zhang 
1197076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
119808401ef6SPierre Jolivet     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
11999566063dSJacob Faibussowitsch     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
12009566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1201076ba34aSJunchao Zhang     /* Form C2_global */
12029371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1203076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1204076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1205076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1206076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1207076ba34aSJunchao Zhang     ab = mmdata->mmAB;
1208f0e6e2d1SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
12099566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1210f0e6e2d1SJunchao Zhang #else
1211f0e6e2d1SJunchao Zhang     PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_REUSE_MATRIX, ab, &ab->B_local));
1212f0e6e2d1SJunchao Zhang #endif
1213076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
121408401ef6SPierre Jolivet     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
12159566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
12169566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
12179371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1218076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
12199566063dSJacob Faibussowitsch     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
12209566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1221076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1222076ba34aSJunchao Zhang 
1223076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1224076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
122508401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
12269566063dSJacob Faibussowitsch     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
12279566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1228076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
12299566063dSJacob Faibussowitsch     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
12309566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
12319371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1232076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1233076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1234076ba34aSJunchao Zhang   }
1235076ba34aSJunchao Zhang   /* Split C_global to form C */
12369566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
12373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1238076ba34aSJunchao Zhang }
1239076ba34aSJunchao Zhang 
1240d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1241d71ae5a4SJacob Faibussowitsch {
1242076ba34aSJunchao Zhang   Mat                          A, B;
1243076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1244076ba34aSJunchao Zhang   MatProductType               ptype;
1245076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1246076ba34aSJunchao Zhang   MatMatStruct                *mm   = NULL;
1247076ba34aSJunchao Zhang   IS                           glob = NULL;
1248076ba34aSJunchao Zhang   const PetscInt              *garray;
1249076ba34aSJunchao Zhang   PetscInt                     m, n, M, N, sz;
1250076ba34aSJunchao Zhang   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1251076ba34aSJunchao Zhang 
1252076ba34aSJunchao Zhang   PetscFunctionBegin;
1253076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
125428b400f6SJacob Faibussowitsch   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1255076ba34aSJunchao Zhang   ptype = product->type;
1256076ba34aSJunchao Zhang   A     = product->A;
1257076ba34aSJunchao Zhang   B     = product->B;
1258076ba34aSJunchao Zhang 
1259076ba34aSJunchao Zhang   switch (ptype) {
12609371c9d4SSatish Balay   case MATPRODUCT_AB:
12619371c9d4SSatish Balay     m = A->rmap->n;
12629371c9d4SSatish Balay     n = B->cmap->n;
12639371c9d4SSatish Balay     M = A->rmap->N;
12649371c9d4SSatish Balay     N = B->cmap->N;
12659371c9d4SSatish Balay     break;
12669371c9d4SSatish Balay   case MATPRODUCT_AtB:
12679371c9d4SSatish Balay     m = A->cmap->n;
12689371c9d4SSatish Balay     n = B->cmap->n;
12699371c9d4SSatish Balay     M = A->cmap->N;
12709371c9d4SSatish Balay     N = B->cmap->N;
12719371c9d4SSatish Balay     break;
12729371c9d4SSatish Balay   case MATPRODUCT_PtAP:
12739371c9d4SSatish Balay     m = B->cmap->n;
12749371c9d4SSatish Balay     n = B->cmap->n;
12759371c9d4SSatish Balay     M = B->cmap->N;
12769371c9d4SSatish Balay     N = B->cmap->N;
12779371c9d4SSatish Balay     break; /* BtAB */
1278d71ae5a4SJacob Faibussowitsch   default:
1279d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1280076ba34aSJunchao Zhang   }
1281076ba34aSJunchao Zhang 
12829566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
12839566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
12849566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
12859566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1286076ba34aSJunchao Zhang 
1287076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1288076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1289076ba34aSJunchao Zhang 
1290076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1291076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
12929566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1293076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1294076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1295076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1296076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
12979566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
12989566063dSJacob Faibussowitsch     PetscCall(ISGetIndices(glob, &garray));
12999566063dSJacob Faibussowitsch     PetscCall(ISGetSize(glob, &sz));
1300076ba34aSJunchao Zhang     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
13019566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
13029566063dSJacob Faibussowitsch     PetscCall(ISRestoreIndices(glob, &garray));
13039566063dSJacob Faibussowitsch     PetscCall(ISDestroy(&glob));
1304076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1305076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1306076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1307076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1308076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1309076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
13109566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1311076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
13129566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
13139566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1314076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1315076ba34aSJunchao Zhang   }
1316076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
13179566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1318076ba34aSJunchao Zhang   C->product->data       = mmdata;
1319076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1320076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
13213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1322076ba34aSJunchao Zhang }
1323076ba34aSJunchao Zhang 
1324d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1325d71ae5a4SJacob Faibussowitsch {
1326076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1327076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1328076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1329076ba34aSJunchao Zhang 
1330076ba34aSJunchao Zhang   PetscFunctionBegin;
1331076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
133248a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1333076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1334076ba34aSJunchao Zhang     switch (product->type) {
1335076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1336076ba34aSJunchao Zhang       if (product->api_user) {
1337d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
13389566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1339d0609cedSBarry Smith         PetscOptionsEnd();
1340076ba34aSJunchao Zhang       } else {
1341d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
13429566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1343d0609cedSBarry Smith         PetscOptionsEnd();
1344076ba34aSJunchao Zhang       }
1345076ba34aSJunchao Zhang       break;
1346076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1347076ba34aSJunchao Zhang       if (product->api_user) {
1348d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
13499566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1350d0609cedSBarry Smith         PetscOptionsEnd();
1351076ba34aSJunchao Zhang       } else {
1352d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
13539566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1354d0609cedSBarry Smith         PetscOptionsEnd();
1355076ba34aSJunchao Zhang       }
1356076ba34aSJunchao Zhang       break;
1357076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1358076ba34aSJunchao Zhang       if (product->api_user) {
1359d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
13609566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1361d0609cedSBarry Smith         PetscOptionsEnd();
1362076ba34aSJunchao Zhang       } else {
1363d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
13649566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1365d0609cedSBarry Smith         PetscOptionsEnd();
1366076ba34aSJunchao Zhang       }
1367076ba34aSJunchao Zhang       break;
1368d71ae5a4SJacob Faibussowitsch     default:
1369d71ae5a4SJacob Faibussowitsch       break;
1370076ba34aSJunchao Zhang     }
1371076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1372076ba34aSJunchao Zhang   }
1373076ba34aSJunchao Zhang   if (match) {
1374076ba34aSJunchao Zhang     switch (product->type) {
1375076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1376076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1377d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1378d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1379d71ae5a4SJacob Faibussowitsch       break;
1380d71ae5a4SJacob Faibussowitsch     default:
1381d71ae5a4SJacob Faibussowitsch       break;
1382076ba34aSJunchao Zhang     }
1383076ba34aSJunchao Zhang   }
1384076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
138548a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
13863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1387076ba34aSJunchao Zhang }
1388076ba34aSJunchao Zhang 
1389d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1390d71ae5a4SJacob Faibussowitsch {
1391394ed5ebSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1392cbc6b225SStefano Zampini   Mat_MPIAIJKokkos *mpikok;
139342550becSJunchao Zhang 
139442550becSJunchao Zhang   PetscFunctionBegin;
139530203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1396cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
13979566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
13989566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
13999566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
1400cbc6b225SStefano Zampini   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1401cbc6b225SStefano Zampini   delete mpikok;
1402394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
14033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
140442550becSJunchao Zhang }
140542550becSJunchao Zhang 
1406d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1407d71ae5a4SJacob Faibussowitsch {
1408394ed5ebSJunchao Zhang   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
140942550becSJunchao Zhang   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
141042550becSJunchao Zhang   Mat                         A = mpiaij->A, B = mpiaij->B;
1411158ec288SJunchao Zhang   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
141242550becSJunchao Zhang   MatScalarKokkosView         Aa, Ba;
1413394ed5ebSJunchao Zhang   MatScalarKokkosView         v1;
141442550becSJunchao Zhang   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
141542550becSJunchao Zhang   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1416158ec288SJunchao Zhang   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1417158ec288SJunchao Zhang   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1418394ed5ebSJunchao Zhang   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1419394ed5ebSJunchao Zhang   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
142042550becSJunchao Zhang   PetscMemType                memtype;
142142550becSJunchao Zhang 
142242550becSJunchao Zhang   PetscFunctionBegin;
14239566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
142442550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1425394ed5ebSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
142642550becSJunchao Zhang   } else {
1427394ed5ebSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
142842550becSJunchao Zhang   }
142942550becSJunchao Zhang 
143042550becSJunchao Zhang   if (imode == INSERT_VALUES) {
14319566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
14329566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1433394ed5ebSJunchao Zhang   } else {
14349566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
14359566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
143642550becSJunchao Zhang   }
143742550becSJunchao Zhang 
143842550becSJunchao Zhang   /* Pack entries to be sent to remote */
14399371c9d4SSatish Balay   Kokkos::parallel_for(
14409371c9d4SSatish Balay     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
144142550becSJunchao Zhang 
144242550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
14439566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1444158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
14459371c9d4SSatish Balay   Kokkos::parallel_for(
14469371c9d4SSatish Balay     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1447158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1448158ec288SJunchao Zhang       if (i < Annz) {
1449158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1450ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1451158ec288SJunchao Zhang       } else {
1452158ec288SJunchao Zhang         i -= Annz;
1453158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1454ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1455158ec288SJunchao Zhang       }
1456158ec288SJunchao Zhang     });
14579566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
145842550becSJunchao Zhang 
1459158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
14609371c9d4SSatish Balay   Kokkos::parallel_for(
14619371c9d4SSatish Balay     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1462158ec288SJunchao Zhang       if (i < Annz2) {
1463158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1464158ec288SJunchao Zhang       } else {
1465158ec288SJunchao Zhang         i -= Annz2;
1466158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1467158ec288SJunchao Zhang       }
1468158ec288SJunchao Zhang     });
146942550becSJunchao Zhang 
1470394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
14719566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
14729566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1473394ed5ebSJunchao Zhang   } else {
14749566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14759566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1476394ed5ebSJunchao Zhang   }
14773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
147842550becSJunchao Zhang }
147942550becSJunchao Zhang 
1480d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1481d71ae5a4SJacob Faibussowitsch {
148242550becSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1483076ba34aSJunchao Zhang 
1484076ba34aSJunchao Zhang   PetscFunctionBegin;
14859566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
14869566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
14879566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
14889566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
148942550becSJunchao Zhang   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
14909566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
14913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1492076ba34aSJunchao Zhang }
1493076ba34aSJunchao Zhang 
1494d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1495d71ae5a4SJacob Faibussowitsch {
14968c3ff71bSJunchao Zhang   Mat         B;
1497076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
14988c3ff71bSJunchao Zhang 
14998c3ff71bSJunchao Zhang   PetscFunctionBegin;
15008c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
15019566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
15028c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
15039566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
15048c3ff71bSJunchao Zhang   }
15058c3ff71bSJunchao Zhang   B = *newmat;
15068c3ff71bSJunchao Zhang 
15076f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
15089566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
15099566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
15109566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
15118c3ff71bSJunchao Zhang 
1512076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
15139566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
15149566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
15159566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1516076ba34aSJunchao Zhang 
15178c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
15188c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
15198c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
15208c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1521076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1522076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
15238c3ff71bSJunchao Zhang 
15249566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
15259566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
15269566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
15279566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
15283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
15298c3ff71bSJunchao Zhang }
15303f3ba80aSJunchao Zhang /*MC
153111a5261eSBarry Smith    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
15328c3ff71bSJunchao Zhang 
15343f3ba80aSJunchao Zhang    A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types
15343f3ba80aSJunchao Zhang 
1535*2ef1f0ffSBarry Smith    Options Database Key:
1536*2ef1f0ffSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
15373f3ba80aSJunchao Zhang 
15383f3ba80aSJunchao Zhang   Level: beginner
15393f3ba80aSJunchao Zhang 
1540*2ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
15413f3ba80aSJunchao Zhang M*/
1542d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1543d71ae5a4SJacob Faibussowitsch {
15448c3ff71bSJunchao Zhang   PetscFunctionBegin;
15459566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
15469566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
15479566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
15483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
15498c3ff71bSJunchao Zhang }
15508c3ff71bSJunchao Zhang 
15518c3ff71bSJunchao Zhang /*@C
155211a5261eSBarry Smith    MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
15538c3ff71bSJunchao Zhang    (the default parallel PETSc format).  This matrix will ultimately be pushed down
15548c3ff71bSJunchao Zhang    to Kokkos for calculations. For good matrix
15558c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
1556*2ef1f0ffSBarry Smith    the parameter `nz` (or the array `nnz`).
15578c3ff71bSJunchao Zhang 
15588c3ff71bSJunchao Zhang    Collective
15598c3ff71bSJunchao Zhang 
15608c3ff71bSJunchao Zhang    Input Parameters:
156111a5261eSBarry Smith +  comm - MPI communicator, set to `PETSC_COMM_SELF`
15628c3ff71bSJunchao Zhang .  m - number of rows
15638c3ff71bSJunchao Zhang .  n - number of columns
15648c3ff71bSJunchao Zhang .  nz - number of nonzeros per row (same for all rows)
15658c3ff71bSJunchao Zhang -  nnz - array containing the number of nonzeros in the various rows
1566*2ef1f0ffSBarry Smith          (possibly different for each row) or `NULL`
15678c3ff71bSJunchao Zhang 
15688c3ff71bSJunchao Zhang    Output Parameter:
15698c3ff71bSJunchao Zhang .  A - the matrix
15708c3ff71bSJunchao Zhang 
1571*2ef1f0ffSBarry Smith    Level: intermediate
1572*2ef1f0ffSBarry Smith 
1573*2ef1f0ffSBarry Smith    Notes:
157411a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
15758c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
157611a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
15778c3ff71bSJunchao Zhang 
1578*2ef1f0ffSBarry Smith    If `nnz` is given then `nz` is ignored
15798c3ff71bSJunchao Zhang 
1580667f096bSBarry Smith    The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
15818c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
1582*2ef1f0ffSBarry Smith    either one (as in Fortran) or zero.
15838c3ff71bSJunchao Zhang 
15848c3ff71bSJunchao Zhang    Specify the preallocated storage with either nz or nnz (not both).
1585*2ef1f0ffSBarry Smith    Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
15868c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
15878c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
15888c3ff71bSJunchao Zhang 
15898c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
15908c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
15918c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
15928c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
15938c3ff71bSJunchao Zhang 
1594*2ef1f0ffSBarry Smith    Developer Note:
1595*2ef1f0ffSBarry Smith    This manual page is for the sequential constructor, not the parallel constructor
15968c3ff71bSJunchao Zhang 
1597*2ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1598*2ef1f0ffSBarry Smith           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
15998c3ff71bSJunchao Zhang @*/
1600d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1601d71ae5a4SJacob Faibussowitsch {
16028c3ff71bSJunchao Zhang   PetscMPIInt size;
16038c3ff71bSJunchao Zhang 
16048c3ff71bSJunchao Zhang   PetscFunctionBegin;
16059566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16069566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
16079566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
16088c3ff71bSJunchao Zhang   if (size > 1) {
16099566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
16109566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
16118c3ff71bSJunchao Zhang   } else {
16129566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16139566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
16148c3ff71bSJunchao Zhang   }
16153ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16168c3ff71bSJunchao Zhang }
16178c3ff71bSJunchao Zhang 
1618a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1619d71ae5a4SJacob Faibussowitsch PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1620d71ae5a4SJacob Faibussowitsch {
1621a587d139SMark   PetscMPIInt                size, rank;
1622a587d139SMark   MPI_Comm                   comm;
1623042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat = NULL;
1624a587d139SMark 
1625a587d139SMark   PetscFunctionBegin;
16269566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
16279566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
16289566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1629a587d139SMark   if (size == 1) {
16309566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
16319566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1632a587d139SMark   } else {
1633a587d139SMark     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
16349566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
16359566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
16369566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
16372c71b3e2SJacob Faibussowitsch     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1638a587d139SMark   }
1639a587d139SMark   // act like MatSetValues because not called on host
1640a587d139SMark   if (A->assembled) {
164148a46eb9SPierre Jolivet     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1642a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1643a587d139SMark   } else {
16449566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1645a587d139SMark   }
1646a587d139SMark   if (!d_mat) {
1647042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1648a587d139SMark     Mat_SeqAIJKokkos     *aijkokA;
1649a587d139SMark     Mat_SeqAIJ           *jaca;
1650a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1651a587d139SMark     Mat                   Amat;
1652042217e8SBarry Smith     PetscInt             *colmap;
1653042217e8SBarry Smith 
1654042217e8SBarry Smith     /* create and copy h_mat */
165549b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
16569566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1657a587d139SMark     if (size == 1) {
1658a587d139SMark       Amat            = A;
1659a587d139SMark       jaca            = (Mat_SeqAIJ *)A->data;
16609371c9d4SSatish Balay       h_mat.rstart    = 0;
16619371c9d4SSatish Balay       h_mat.rend      = A->rmap->n;
16629371c9d4SSatish Balay       h_mat.cstart    = 0;
16639371c9d4SSatish Balay       h_mat.cend      = A->cmap->n;
1664a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1665a587d139SMark       h_mat.offdiag.a                   = NULL;
1666a587d139SMark       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1667a587d139SMark     } else {
1668a587d139SMark       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1669a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1670a587d139SMark       PetscInt          ii;
1671a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1672042217e8SBarry Smith 
1673a587d139SMark       Amat    = aij->A;
1674a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1675a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1676a587d139SMark       jaca    = (Mat_SeqAIJ *)aij->A->data;
167708401ef6SPierre Jolivet       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
167808401ef6SPierre Jolivet       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1679a587d139SMark       aij->donotstash          = PETSC_TRUE;
1680a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1681a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
16829566063dSJacob Faibussowitsch       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1683042217e8SBarry Smith       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1684a587d139SMark       // allocate B copy data
16859371c9d4SSatish Balay       h_mat.rstart = A->rmap->rstart;
16869371c9d4SSatish Balay       h_mat.rend   = A->rmap->rend;
16879371c9d4SSatish Balay       h_mat.cstart = A->cmap->rstart;
16889371c9d4SSatish Balay       h_mat.cend   = A->cmap->rend;
1689a587d139SMark       nnz          = jacb->i[n];
1690a587d139SMark       if (jacb->compressedrow.use) {
1691a587d139SMark         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1692300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1693300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1694300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1695a587d139SMark       } else {
169699551766SMark Adams         h_mat.offdiag.i = aijkokB->i_device_data();
1697a587d139SMark       }
169899551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1699076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1700a587d139SMark       {
1701042217e8SBarry Smith         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1702300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1703300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1704300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
17059566063dSJacob Faibussowitsch         PetscCall(PetscFree(colmap));
1706a587d139SMark       }
1707a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1708a587d139SMark       h_mat.offdiag.n                 = n;
1709a587d139SMark     }
1710a587d139SMark     // allocate A copy data
1711a587d139SMark     nnz                          = jaca->i[n];
1712a587d139SMark     h_mat.diag.n                 = n;
1713a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
17149566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1715d5b43468SJose E. Roman     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
171699551766SMark Adams     h_mat.diag.i = aijkokA->i_device_data();
171799551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1718076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1719da81f932SPierre Jolivet     // copy pointers and metadata to device
17209566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
17219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
17229566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1723a587d139SMark   }
1724a587d139SMark   *B           = d_mat;       // return it, set it in Mat, and set it up
1725a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
17263ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1727a587d139SMark }
1728076ba34aSJunchao Zhang 
1729d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1730d71ae5a4SJacob Faibussowitsch {
1731076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1732076ba34aSJunchao Zhang 
1733076ba34aSJunchao Zhang   PetscFunctionBegin;
1734076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1735076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1736076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1737076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
17383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1739076ba34aSJunchao Zhang }
1740076ba34aSJunchao Zhang 
1741d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1742d71ae5a4SJacob Faibussowitsch {
1743076ba34aSJunchao Zhang   PetscMPIInt size;
1744076ba34aSJunchao Zhang   Mat         Ad, Ao;
1745076ba34aSJunchao Zhang   const char *amask, *bmask;
1746076ba34aSJunchao Zhang 
1747076ba34aSJunchao Zhang   PetscFunctionBegin;
17489566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1749076ba34aSJunchao Zhang 
1750076ba34aSJunchao Zhang   if (size == 1) {
17519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
17529566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1753076ba34aSJunchao Zhang   } else {
1754076ba34aSJunchao Zhang     Ad = ((Mat_MPIAIJ *)A->data)->A;
1755076ba34aSJunchao Zhang     Ao = ((Mat_MPIAIJ *)A->data)->B;
17569566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
17579566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
17589566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1759076ba34aSJunchao Zhang   }
17603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1761076ba34aSJunchao Zhang }
1762