xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 11a5261e40035b7c793f2783a2ba6c7cd4f3b077)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscsf.h>
342550becSJunchao Zhang #include <petsc/private/sfimpl.h>
48c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
542550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
711d22bbfSJunchao Zhang 
89371c9d4SSatish Balay PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) {
95519a089SJose E. Roman   Mat_SeqAIJKokkos *aijkok;
108c3ff71bSJunchao Zhang 
118c3ff71bSJunchao Zhang   PetscFunctionBegin;
129566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
135519a089SJose E. Roman   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
14a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
15a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
16a587d139SMark   }
17a587d139SMark 
188c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
198c3ff71bSJunchao Zhang }
208c3ff71bSJunchao Zhang 
219371c9d4SSatish Balay PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) {
228c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
238c3ff71bSJunchao Zhang 
248c3ff71bSJunchao Zhang   PetscFunctionBegin;
259566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
269566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
276a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
288c3ff71bSJunchao Zhang   if (d_nnz) {
296a29ce69SStefano Zampini     PetscInt i;
30ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
318c3ff71bSJunchao Zhang   }
328c3ff71bSJunchao Zhang   if (o_nnz) {
336a29ce69SStefano Zampini     PetscInt i;
34ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
358c3ff71bSJunchao Zhang   }
366a29ce69SStefano Zampini #endif
376a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
389566063dSJacob Faibussowitsch   PetscCall(PetscTableDestroy(&mpiaij->colmap));
396a29ce69SStefano Zampini #else
409566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
416a29ce69SStefano Zampini #endif
429566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
439566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
449566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
456a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
469566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
476a29ce69SStefano Zampini 
486a29ce69SStefano Zampini   if (!mpiaij->A) {
499566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
509566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
519566063dSJacob Faibussowitsch     PetscCall(PetscLogObjectParent((PetscObject)mat, (PetscObject)mpiaij->A));
526a29ce69SStefano Zampini   }
536a29ce69SStefano Zampini   if (!mpiaij->B) {
546a29ce69SStefano Zampini     PetscMPIInt size;
559566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
569566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
579566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
589566063dSJacob Faibussowitsch     PetscCall(PetscLogObjectParent((PetscObject)mat, (PetscObject)mpiaij->B));
598c3ff71bSJunchao Zhang   }
609566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
619566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
629566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
639566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
648c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
658c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
668c3ff71bSJunchao Zhang }
678c3ff71bSJunchao Zhang 
689371c9d4SSatish Balay PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
698c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
708c3ff71bSJunchao Zhang   PetscInt    nt;
718c3ff71bSJunchao Zhang 
728c3ff71bSJunchao Zhang   PetscFunctionBegin;
739566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
7408401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
759566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
769566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
779566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
789566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
798c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
808c3ff71bSJunchao Zhang }
818c3ff71bSJunchao Zhang 
829371c9d4SSatish Balay PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) {
838c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
848c3ff71bSJunchao Zhang   PetscInt    nt;
858c3ff71bSJunchao Zhang 
868c3ff71bSJunchao Zhang   PetscFunctionBegin;
879566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8808401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
899566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
909566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
919566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
929566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
938c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
948c3ff71bSJunchao Zhang }
958c3ff71bSJunchao Zhang 
969371c9d4SSatish Balay PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) {
978c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
988c3ff71bSJunchao Zhang   PetscInt    nt;
998c3ff71bSJunchao Zhang 
1008c3ff71bSJunchao Zhang   PetscFunctionBegin;
1019566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
10208401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1039566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1049566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1059566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1069566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1078c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1088c3ff71bSJunchao Zhang }
1098c3ff71bSJunchao Zhang 
110076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
111076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
112076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
113076ba34aSJunchao Zhang */
1149371c9d4SSatish Balay PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) {
115076ba34aSJunchao Zhang   Mat             Ad, Ao;
116076ba34aSJunchao Zhang   const PetscInt *cmap;
117076ba34aSJunchao Zhang 
118076ba34aSJunchao Zhang   PetscFunctionBegin;
1199566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1209566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
121076ba34aSJunchao Zhang   if (glob) {
122076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1239566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1249566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1259566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1269566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
127076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
128076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1299566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
130076ba34aSJunchao Zhang   }
131076ba34aSJunchao Zhang   PetscFunctionReturn(0);
132076ba34aSJunchao Zhang }
133076ba34aSJunchao Zhang 
134076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
135076ba34aSJunchao Zhang struct MatMatStruct {
136076ba34aSJunchao Zhang   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
137076ba34aSJunchao Zhang   PetscSF             sf;      /* SF to send/recv matrix entries */
138076ba34aSJunchao Zhang   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
139076ba34aSJunchao Zhang   Mat                 C1, C2, B_local;
140076ba34aSJunchao Zhang   KokkosCsrMatrix     C1_global, C2_global, C_global;
141076ba34aSJunchao Zhang   KernelHandle        kh;
142076ba34aSJunchao Zhang   MatMatStruct() {
143076ba34aSJunchao Zhang     C1 = C2 = B_local = NULL;
144076ba34aSJunchao Zhang     sf                = NULL;
145076ba34aSJunchao Zhang   }
146076ba34aSJunchao Zhang 
147076ba34aSJunchao Zhang   ~MatMatStruct() {
148076ba34aSJunchao Zhang     MatDestroy(&C1);
149076ba34aSJunchao Zhang     MatDestroy(&C2);
150076ba34aSJunchao Zhang     MatDestroy(&B_local);
151076ba34aSJunchao Zhang     PetscSFDestroy(&sf);
152076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
153076ba34aSJunchao Zhang   }
154076ba34aSJunchao Zhang };
155076ba34aSJunchao Zhang 
156076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
157076ba34aSJunchao Zhang   MatColIdxKokkosView rows;
158076ba34aSJunchao Zhang   MatRowMapKokkosView rowoffset;
159076ba34aSJunchao Zhang   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
160076ba34aSJunchao Zhang 
161076ba34aSJunchao Zhang   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
162076ba34aSJunchao Zhang   ~MatMatStruct_AB() {
163076ba34aSJunchao Zhang     MatDestroy(&B_other);
164076ba34aSJunchao Zhang     MatDestroy(&C_petsc);
165076ba34aSJunchao Zhang   }
166076ba34aSJunchao Zhang };
167076ba34aSJunchao Zhang 
168076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
169076ba34aSJunchao Zhang   MatRowMapKokkosView srcrowoffset, dstrowoffset;
170076ba34aSJunchao Zhang };
171076ba34aSJunchao Zhang 
1729371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
173076ba34aSJunchao Zhang   MatMatStruct_AB  *mmAB;
174076ba34aSJunchao Zhang   MatMatStruct_AtB *mmAtB;
175076ba34aSJunchao Zhang   PetscBool         reusesym;
176076ba34aSJunchao Zhang 
177076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
178076ba34aSJunchao Zhang   ~MatProductData_MPIAIJKokkos() {
179076ba34aSJunchao Zhang     delete mmAB;
180076ba34aSJunchao Zhang     delete mmAtB;
181076ba34aSJunchao Zhang   }
182076ba34aSJunchao Zhang };
183076ba34aSJunchao Zhang 
1849371c9d4SSatish Balay static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) {
185076ba34aSJunchao Zhang   PetscFunctionBegin;
1869566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
187076ba34aSJunchao Zhang   PetscFunctionReturn(0);
188076ba34aSJunchao Zhang }
189076ba34aSJunchao Zhang 
190076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
191076ba34aSJunchao Zhang 
192076ba34aSJunchao Zhang    Input Parameters:
193076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
194076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
195076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
196076ba34aSJunchao Zhang 
197076ba34aSJunchao Zhang    Output Parameters:
198076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
199076ba34aSJunchao Zhang  */
2009371c9d4SSatish Balay static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat) {
201076ba34aSJunchao Zhang   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
202076ba34aSJunchao Zhang   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
203076ba34aSJunchao Zhang 
204076ba34aSJunchao Zhang   PetscFunctionBegin;
2059371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
2069371c9d4SSatish Balay     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
2079566063dSJacob Faibussowitsch   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
208076ba34aSJunchao Zhang   PetscFunctionReturn(0);
209076ba34aSJunchao Zhang }
210076ba34aSJunchao Zhang 
211076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
212076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
213076ba34aSJunchao Zhang 
214076ba34aSJunchao Zhang   Input Parameters:
215076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
216076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
217076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
218076ba34aSJunchao Zhang 
219076ba34aSJunchao Zhang   Output Parameters:
220076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
221076ba34aSJunchao Zhang */
2229371c9d4SSatish Balay static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B) {
223076ba34aSJunchao Zhang   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
224076ba34aSJunchao Zhang   PetscInt          m, n, M, N, Am, An, Bm, Bn;
225076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
226076ba34aSJunchao Zhang 
227076ba34aSJunchao Zhang   PetscFunctionBegin;
2289566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2299566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2309566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2319566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
232076ba34aSJunchao Zhang 
233aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
23408401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
23508401ef6SPierre Jolivet   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
23608401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
237076ba34aSJunchao Zhang   mpiaij->A = A;
238076ba34aSJunchao Zhang   mpiaij->B = B;
239076ba34aSJunchao Zhang 
240076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
241076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
242076ba34aSJunchao Zhang 
2439566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2449566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
245076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
246076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
247076ba34aSJunchao Zhang   */
2489566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2499566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2509566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
251076ba34aSJunchao Zhang 
252076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
253076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
254076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
255076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
256076ba34aSJunchao Zhang   PetscFunctionReturn(0);
257076ba34aSJunchao Zhang }
258076ba34aSJunchao Zhang 
259076ba34aSJunchao Zhang /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
260076ba34aSJunchao Zhang 
261076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
262076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
263076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
264076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
265076ba34aSJunchao Zhang 
266076ba34aSJunchao Zhang    Collective on comm of ownerSF
267076ba34aSJunchao Zhang 
268076ba34aSJunchao Zhang    Input Parameters:
269076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
270076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
271076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
272076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
273076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
274076ba34aSJunchao Zhang 
275076ba34aSJunchao Zhang    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
276076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
277076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
278076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
279076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
280076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
281076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
282076ba34aSJunchao Zhang */
2839371c9d4SSatish Balay static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C) {
284076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok, *ckok;
285076ba34aSJunchao Zhang 
286076ba34aSJunchao Zhang   PetscFunctionBegin;
2879566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
288076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
289076ba34aSJunchao Zhang 
290076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
291076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
292076ba34aSJunchao Zhang 
293076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
294076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
295076ba34aSJunchao Zhang     const auto &Ca = ckok->a_dual.view_device();
296076ba34aSJunchao Zhang 
297076ba34aSJunchao Zhang     /* Copy Ba to abuf */
2989371c9d4SSatish Balay     Kokkos::parallel_for(
2999371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
300076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
301076ba34aSJunchao Zhang         PetscInt r    = rows(i);
302076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
3039371c9d4SSatish Balay         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
304076ba34aSJunchao Zhang       });
305076ba34aSJunchao Zhang 
306076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
3079566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
3089566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
309076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
310076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
311076ba34aSJunchao Zhang     MPI_Comm    comm;
312076ba34aSJunchao Zhang     PetscMPIInt tag;
313076ba34aSJunchao Zhang     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
314076ba34aSJunchao Zhang 
3159566063dSJacob Faibussowitsch     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
3169566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
317076ba34aSJunchao Zhang     Cm = nleaves; /* row size of C */
318076ba34aSJunchao Zhang     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
319076ba34aSJunchao Zhang 
320076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
321076ba34aSJunchao Zhang     PetscInt       *Browlens;
322076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
3239566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nroots, &Browlens));
324076ba34aSJunchao Zhang     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
325076ba34aSJunchao Zhang 
326076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
327076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
328076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
329076ba34aSJunchao Zhang     Ci_h[0] = 0;
3309566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
3319566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
332076ba34aSJunchao Zhang     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
333076ba34aSJunchao Zhang     Cnnz = Ci_h[Cm];
334076ba34aSJunchao Zhang     Ci.modify_host();
335076ba34aSJunchao Zhang     Ci.sync_device();
336076ba34aSJunchao Zhang 
337076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
338076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cj("j", Cnnz);
339076ba34aSJunchao Zhang     MatScalarKokkosDualView Ca("a", Cnnz);
340076ba34aSJunchao Zhang 
341076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
342076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
343076ba34aSJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
344076ba34aSJunchao Zhang     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
345076ba34aSJunchao Zhang     MPI_Request       *reqs;
346076ba34aSJunchao Zhang 
3479566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
3489566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
349076ba34aSJunchao Zhang 
350076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
351076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
352076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
353076ba34aSJunchao Zhang 
354076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
355076ba34aSJunchao Zhang     */
3569566063dSJacob Faibussowitsch     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
357076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
358076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
359076ba34aSJunchao Zhang 
360076ba34aSJunchao Zhang     sdisp[0]  = 0;
361076ba34aSJunchao Zhang     rowptr[0] = 0;
362076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) { /* for each receiver */
363076ba34aSJunchao Zhang       PetscInt len, nz = 0;
364076ba34aSJunchao Zhang       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
365076ba34aSJunchao Zhang         len           = Browlens[irootloc[j]];
366076ba34aSJunchao Zhang         rowptr[j + 1] = rowptr[j] + len;
367076ba34aSJunchao Zhang         nz += len;
368076ba34aSJunchao Zhang       }
369076ba34aSJunchao Zhang       sdisp[i + 1] = sdisp[i] + nz;
370076ba34aSJunchao Zhang     }
3719566063dSJacob Faibussowitsch     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
3729566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
3739566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
3749566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
375076ba34aSJunchao Zhang 
376076ba34aSJunchao Zhang     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
377076ba34aSJunchao Zhang     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
378076ba34aSJunchao Zhang     PetscSFNode *iremote;
3799566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote));
380076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) { /* for each sender */
381076ba34aSJunchao Zhang       k = 0;
382076ba34aSJunchao Zhang       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
383076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
384076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
385076ba34aSJunchao Zhang         k++;
386076ba34aSJunchao Zhang       }
387076ba34aSJunchao Zhang     }
388076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
3899566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &bcastSF));
3909566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
391076ba34aSJunchao Zhang 
392076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
393076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
394076ba34aSJunchao Zhang     */
395076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
396076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
397076ba34aSJunchao Zhang     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
398076ba34aSJunchao Zhang 
399076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
400076ba34aSJunchao Zhang 
401076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
402076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
403076ba34aSJunchao Zhang 
404076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
405076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
406076ba34aSJunchao Zhang     const auto &Bj = bkok->j_dual.view_device();
407076ba34aSJunchao Zhang 
408076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
4099371c9d4SSatish Balay     Kokkos::parallel_for(
4109371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
411076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
412076ba34aSJunchao Zhang         PetscInt r    = rows(i);
413076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
414076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
415076ba34aSJunchao Zhang           abuf(base + k) = Ba(Bi(r) + k);
416076ba34aSJunchao Zhang           jbuf(base + k) = l2g(Bj(Bi(r) + k));
417076ba34aSJunchao Zhang         });
418076ba34aSJunchao Zhang       });
419076ba34aSJunchao Zhang 
420076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
4219566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4229566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
4239566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4249566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
425076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
426076ba34aSJunchao Zhang     Cj.sync_host();
427076ba34aSJunchao Zhang     Ca.modify_device();
428076ba34aSJunchao Zhang 
429076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
430076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
4319566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
4329566063dSJacob Faibussowitsch     PetscCall(PetscFree3(sdisp, rdisp, reqs));
4339566063dSJacob Faibussowitsch     PetscCall(PetscFree(Browlens));
43498921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
435076ba34aSJunchao Zhang   PetscFunctionReturn(0);
436076ba34aSJunchao Zhang }
437076ba34aSJunchao Zhang 
438076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
439076ba34aSJunchao Zhang 
440076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
441076ba34aSJunchao Zhang 
442076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
443076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
444076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
445076ba34aSJunchao Zhang 
446076ba34aSJunchao Zhang   Input Parameters:
447076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
448076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
449076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
450076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
451076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
452076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
453076ba34aSJunchao Zhang 
454076ba34aSJunchao Zhang   Output Parameters:
455076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
456076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
457076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
458076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
459076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
460076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
461076ba34aSJunchao Zhang 
462076ba34aSJunchao Zhang    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
463076ba34aSJunchao Zhang  */
4649371c9d4SSatish Balay static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C) {
465076ba34aSJunchao Zhang   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
466076ba34aSJunchao Zhang   const PetscInt   *Ai;
467076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok;
468076ba34aSJunchao Zhang 
469076ba34aSJunchao Zhang   PetscFunctionBegin;
4709566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
4719566063dSJacob Faibussowitsch   PetscCall(MatGetSize(A, &Am, &An));
472076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
473076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
474076ba34aSJunchao Zhang   Annz = Ai[Am];
475076ba34aSJunchao Zhang 
476076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
477076ba34aSJunchao Zhang     /* Send Aa to abuf */
4789566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
4799566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
480076ba34aSJunchao Zhang 
481076ba34aSJunchao Zhang     /* Copy abuf to Ca */
482076ba34aSJunchao Zhang     const MatScalarKokkosView &Ca = C.values;
483076ba34aSJunchao Zhang     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
4849371c9d4SSatish Balay     Kokkos::parallel_for(
4859371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
486076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
487076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
488076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
489076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
490076ba34aSJunchao Zhang       });
491076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
492076ba34aSJunchao Zhang     MPI_Comm     comm;
493076ba34aSJunchao Zhang     MPI_Request *reqs;
494076ba34aSJunchao Zhang     PetscMPIInt  tag;
495076ba34aSJunchao Zhang     PetscInt     Cm;
496076ba34aSJunchao Zhang 
4979566063dSJacob Faibussowitsch     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
4989566063dSJacob Faibussowitsch     PetscCall(PetscCommGetNewTag(comm, &tag));
499076ba34aSJunchao Zhang 
500076ba34aSJunchao Zhang     PetscInt           niranks, nranks, nroots, nleaves;
501076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
502076ba34aSJunchao Zhang     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
5039566063dSJacob Faibussowitsch     PetscCall(PetscSFSetUp(ownerSF));
5049566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
5059566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
5069566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
50708401ef6SPierre Jolivet     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
508076ba34aSJunchao Zhang     Cm    = nroots;
509076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
510076ba34aSJunchao Zhang 
511076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
512076ba34aSJunchao Zhang     PetscInt               *srowlens;                              /* send buf of row lens */
513076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
514076ba34aSJunchao Zhang     PetscInt               *rrowlens = rrowlens_h.data();
515076ba34aSJunchao Zhang 
5169566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
517076ba34aSJunchao Zhang     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
518076ba34aSJunchao Zhang     rrowlens[0] = 0;
519076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
5209566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
5219566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
5229566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
523076ba34aSJunchao Zhang 
524076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
525076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
526076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
527076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
528076ba34aSJunchao Zhang 
529076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {
530076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
531076ba34aSJunchao Zhang #if defined(PETSC_USE_DEBUG)
532aed4548fSBarry Smith       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
533076ba34aSJunchao Zhang #endif
534076ba34aSJunchao Zhang       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
535076ba34aSJunchao Zhang     }
536076ba34aSJunchao Zhang     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
537076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
538076ba34aSJunchao Zhang 
539076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
540076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
541076ba34aSJunchao Zhang     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
542076ba34aSJunchao Zhang     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
543076ba34aSJunchao Zhang 
5449566063dSJacob Faibussowitsch     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
545076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
546076ba34aSJunchao Zhang       r                    = rows[i];                   /* row id in C */
547076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
548076ba34aSJunchao Zhang       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
549076ba34aSJunchao Zhang     }
5509566063dSJacob Faibussowitsch     PetscCall(PetscFree(currowlens));
551076ba34aSJunchao Zhang 
552076ba34aSJunchao Zhang     rrowlens--;
553076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
554076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
555076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
556076ba34aSJunchao Zhang 
557076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
558076ba34aSJunchao Zhang     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
5599566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
560076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
5619566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
5629566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
5639566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
564076ba34aSJunchao Zhang 
565076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
566076ba34aSJunchao Zhang     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
567076ba34aSJunchao Zhang     PetscSFNode *iremote;
5689566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
569076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) {
570076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
571076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
572076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
573076ba34aSJunchao Zhang       for (PetscInt k = 0; k < nz; k++) {
574076ba34aSJunchao Zhang         iremote[leafbase + k].rank  = ranks[i];
575076ba34aSJunchao Zhang         iremote[leafbase + k].index = rootbase + k;
576076ba34aSJunchao Zhang       }
577076ba34aSJunchao Zhang     }
5789566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &reduceSF));
5799566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
5809566063dSJacob Faibussowitsch     PetscCall(PetscFree2(sdisp, rdisp));
581076ba34aSJunchao Zhang 
582076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
583076ba34aSJunchao Zhang 
584076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
585076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
586076ba34aSJunchao Zhang     if (local) {
587076ba34aSJunchao Zhang       Ajg                           = MatColIdxKokkosView("j", Annz);
588076ba34aSJunchao Zhang       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
5899371c9d4SSatish Balay       Kokkos::parallel_for(
5909371c9d4SSatish Balay         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
591076ba34aSJunchao Zhang     } else {
592076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
593076ba34aSJunchao Zhang     }
594076ba34aSJunchao Zhang 
595076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", Cnnz);
596076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", Cnnz);
5979566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
5989566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
5999566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
6009566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
601076ba34aSJunchao Zhang 
602076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
603076ba34aSJunchao Zhang     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
604076ba34aSJunchao Zhang     MatColIdxKokkosView Cj("j", Cnnz);
605076ba34aSJunchao Zhang     MatScalarKokkosView Ca("a", Cnnz);
606076ba34aSJunchao Zhang 
6079371c9d4SSatish Balay     Kokkos::parallel_for(
6089371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
609076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
610076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
611076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
612076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
613076ba34aSJunchao Zhang           Ca(dst + k) = abuf(src + k);
614076ba34aSJunchao Zhang           Cj(dst + k) = jbuf(src + k);
615076ba34aSJunchao Zhang         });
616076ba34aSJunchao Zhang       });
617076ba34aSJunchao Zhang 
618076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
619076ba34aSJunchao Zhang     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
6209566063dSJacob Faibussowitsch     PetscCall(PetscFree2(srowlens, reqs));
62198921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
622076ba34aSJunchao Zhang   PetscFunctionReturn(0);
623076ba34aSJunchao Zhang }
624076ba34aSJunchao Zhang 
625*11a5261eSBarry Smith /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
626076ba34aSJunchao Zhang 
627076ba34aSJunchao Zhang   Input Parameters:
628*11a5261eSBarry Smith +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
629076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
630076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
631*11a5261eSBarry Smith -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
632076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
633076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
634076ba34aSJunchao Zhang 
635076ba34aSJunchao Zhang   Output Parameters:
636*11a5261eSBarry Smith +  C        - the updated `MATMPIAIJKOKKOS` matrix
637*11a5261eSBarry Smith -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
638076ba34aSJunchao Zhang 
639*11a5261eSBarry Smith   Note:
640*11a5261eSBarry Smith    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
641*11a5261eSBarry Smith 
642*11a5261eSBarry Smith .seealso: `MATMPIAIJKOKKOS`
643076ba34aSJunchao Zhang  */
6449371c9d4SSatish Balay static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart) {
645076ba34aSJunchao Zhang   const MatScalarKokkosView      &Ca = csrmat.values;
646076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
647076ba34aSJunchao Zhang   PetscInt                        m, n, N;
648076ba34aSJunchao Zhang 
649076ba34aSJunchao Zhang   PetscFunctionBegin;
6509566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(C, &m, &n));
6519566063dSJacob Faibussowitsch   PetscCall(MatGetSize(C, NULL, &N));
652076ba34aSJunchao Zhang 
653076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
654076ba34aSJunchao Zhang     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
655076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
656076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
657076ba34aSJunchao Zhang     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
658076ba34aSJunchao Zhang     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
659076ba34aSJunchao Zhang 
660076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
6619371c9d4SSatish Balay     Kokkos::parallel_for(
6629371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
663076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
664076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
665076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
666076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
667076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
668076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
669076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
670076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
671076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
672076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
673076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
674076ba34aSJunchao Zhang           } else { /* k in [cdend, clen) */
675076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
676076ba34aSJunchao Zhang           }
677076ba34aSJunchao Zhang         });
678076ba34aSJunchao Zhang       });
679076ba34aSJunchao Zhang 
680076ba34aSJunchao Zhang     akok->a_dual.modify_device();
681076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
682076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
683076ba34aSJunchao Zhang     Mat                        Cd, Co;
684076ba34aSJunchao Zhang     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
685076ba34aSJunchao Zhang     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
686076ba34aSJunchao Zhang     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
687076ba34aSJunchao Zhang     PetscInt                   cstart, cend;
688076ba34aSJunchao Zhang 
689076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
690076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
691076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
692076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
693076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
694076ba34aSJunchao Zhang      */
695076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart", m);
6969566063dSJacob Faibussowitsch     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
697076ba34aSJunchao Zhang 
698076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
699076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
700076ba34aSJunchao Zhang      */
7019371c9d4SSatish Balay     Kokkos::parallel_for(
7029371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
703076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
704076ba34aSJunchao Zhang                                                    PetscInt i = t.league_rank(); /* row i */
705076ba34aSJunchao Zhang                                                    PetscInt j, first, count, step;
706076ba34aSJunchao Zhang 
707076ba34aSJunchao Zhang                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
708076ba34aSJunchao Zhang                                                      Cdi(0) = 0;
709076ba34aSJunchao Zhang                                                      Coi(0) = 0;
710076ba34aSJunchao Zhang                                                    }
711076ba34aSJunchao Zhang 
712076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
713076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
714076ba34aSJunchao Zhang         */
715076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - Ci(i);
716076ba34aSJunchao Zhang                                                    first = Ci(i);
717076ba34aSJunchao Zhang                                                    while (count > 0) {
718076ba34aSJunchao Zhang                                                      j    = first;
719076ba34aSJunchao Zhang                                                      step = count / 2;
720076ba34aSJunchao Zhang                                                      j += step;
721076ba34aSJunchao Zhang                                                      if (Cj(j) < cstart) {
722076ba34aSJunchao Zhang                                                        first = ++j;
723076ba34aSJunchao Zhang                                                        count -= step + 1;
724076ba34aSJunchao Zhang                                                      } else count = step;
725076ba34aSJunchao Zhang                                                    }
726076ba34aSJunchao Zhang                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
727076ba34aSJunchao Zhang 
728076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
729076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - first;
730076ba34aSJunchao Zhang                                                    while (count > 0) {
731076ba34aSJunchao Zhang                                                      j    = first;
732076ba34aSJunchao Zhang                                                      step = count / 2;
733076ba34aSJunchao Zhang                                                      j += step;
734076ba34aSJunchao Zhang                                                      if (Cj(j) < cend) {
735076ba34aSJunchao Zhang                                                        first = ++j;
736076ba34aSJunchao Zhang                                                        count -= step + 1;
737076ba34aSJunchao Zhang                                                      } else count = step;
738076ba34aSJunchao Zhang                                                    }
739076ba34aSJunchao Zhang                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
740076ba34aSJunchao Zhang                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
741076ba34aSJunchao Zhang         });
742076ba34aSJunchao Zhang       });
743076ba34aSJunchao Zhang 
744076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
7459371c9d4SSatish Balay     Kokkos::parallel_scan(
7469371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
747076ba34aSJunchao Zhang         update += Cdi(i);
748076ba34aSJunchao Zhang         if (final) Cdi(i) = update;
749076ba34aSJunchao Zhang       });
7509371c9d4SSatish Balay     Kokkos::parallel_scan(
7519371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
752076ba34aSJunchao Zhang         update += Coi(i);
753076ba34aSJunchao Zhang         if (final) Coi(i) = update;
754076ba34aSJunchao Zhang       });
755076ba34aSJunchao Zhang 
756076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
757076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
758076ba34aSJunchao Zhang     */
759076ba34aSJunchao Zhang     Cdi_dual.modify_device();
760076ba34aSJunchao Zhang     Coi_dual.modify_device();
761076ba34aSJunchao Zhang     Cdi_dual.sync_host();
762076ba34aSJunchao Zhang     Coi_dual.sync_host();
763076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
764076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
765076ba34aSJunchao Zhang 
766076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
767076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
768076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
769076ba34aSJunchao Zhang 
770076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
771076ba34aSJunchao Zhang     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
772076ba34aSJunchao Zhang     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
773076ba34aSJunchao Zhang 
7749371c9d4SSatish Balay     Kokkos::parallel_for(
7759371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
776076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
777076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
778076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
779076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
780076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
781076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
782076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
783076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
784076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
785076ba34aSJunchao Zhang             Coj(Coi(i) + k) = Cj(Ci(i) + k);
786076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
787076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
788076ba34aSJunchao Zhang             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
789076ba34aSJunchao Zhang           } else {                                                /* k in [cdend, clen) */
790076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
791076ba34aSJunchao Zhang             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
792076ba34aSJunchao Zhang           }
793076ba34aSJunchao Zhang         });
794076ba34aSJunchao Zhang       });
795076ba34aSJunchao Zhang 
796076ba34aSJunchao Zhang     Cdj_dual.modify_device();
797076ba34aSJunchao Zhang     Cda_dual.modify_device();
798076ba34aSJunchao Zhang     Coj_dual.modify_device();
799076ba34aSJunchao Zhang     Coa_dual.modify_device();
800076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
801076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
802076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
8039566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
8049566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
8059566063dSJacob Faibussowitsch     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
806076ba34aSJunchao Zhang   }
807076ba34aSJunchao Zhang   PetscFunctionReturn(0);
808076ba34aSJunchao Zhang }
809076ba34aSJunchao Zhang 
810076ba34aSJunchao Zhang /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
811076ba34aSJunchao Zhang 
812076ba34aSJunchao Zhang   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
813076ba34aSJunchao Zhang 
814076ba34aSJunchao Zhang   Input Parameters:
815076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
816076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
817076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
818076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
819076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array.
820076ba34aSJunchao Zhang 
821076ba34aSJunchao Zhang   Output Parameters:
822076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
823076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
824076ba34aSJunchao Zhang 
825*11a5261eSBarry Smith   Note:
826*11a5261eSBarry Smith   the input matrix's col ids and col size will be changed.
827076ba34aSJunchao Zhang */
8289371c9d4SSatish Balay static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g) {
829076ba34aSJunchao Zhang   Mat_SeqAIJKokkos      *ckok;
830076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
831076ba34aSJunchao Zhang   const PetscInt        *garray;
832076ba34aSJunchao Zhang   PetscInt               sz;
833076ba34aSJunchao Zhang 
834076ba34aSJunchao Zhang   PetscFunctionBegin;
835076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
8369566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
837076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
838076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
839076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
840076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
841076ba34aSJunchao Zhang 
842076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
8439566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
8449566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
84508401ef6SPierre Jolivet   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
846076ba34aSJunchao Zhang 
847076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray, sz);
848076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g", sz);
849076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g, tmp);
850076ba34aSJunchao Zhang 
8519566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
8529566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
853076ba34aSJunchao Zhang   PetscFunctionReturn(0);
854076ba34aSJunchao Zhang }
855076ba34aSJunchao Zhang 
856076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
857076ba34aSJunchao Zhang 
858076ba34aSJunchao Zhang   Input Parameters:
859076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
860076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
861076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
862076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
863076ba34aSJunchao Zhang 
864*11a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
865076ba34aSJunchao Zhang */
8669371c9d4SSatish Balay static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) {
867076ba34aSJunchao Zhang   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
868076ba34aSJunchao Zhang   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
869076ba34aSJunchao Zhang   IS                       glob = NULL;
870076ba34aSJunchao Zhang   const PetscInt          *garray;
871076ba34aSJunchao Zhang   PetscInt                 N = B->cmap->N, sz;
872076ba34aSJunchao Zhang   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
873076ba34aSJunchao Zhang   MatColIdxKokkosView      l2g2;
874076ba34aSJunchao Zhang   Mat                      C1, C2; /* intermediate matrices */
875076ba34aSJunchao Zhang 
876076ba34aSJunchao Zhang   PetscFunctionBegin;
877076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
8789566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
8799566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
8809566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
8819566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
882076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
8839566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
884dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
885076ba34aSJunchao Zhang 
8869566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(glob, &garray));
8879566063dSJacob Faibussowitsch   PetscCall(ISGetSize(glob, &sz));
888076ba34aSJunchao Zhang   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
889076ba34aSJunchao Zhang   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
8909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
891076ba34aSJunchao Zhang 
892076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
8939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
894076ba34aSJunchao Zhang 
895076ba34aSJunchao Zhang   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
8969566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
8979566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
8989566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
8999566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
900076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
9019566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
902dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
9039566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
904076ba34aSJunchao Zhang 
905076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
906076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
907076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
908076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
909076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
910076ba34aSJunchao Zhang 
911076ba34aSJunchao Zhang   mm->C1 = C1;
912076ba34aSJunchao Zhang   mm->C2 = C2;
9139566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(glob, &garray));
9149566063dSJacob Faibussowitsch   PetscCall(ISDestroy(&glob));
915076ba34aSJunchao Zhang   PetscFunctionReturn(0);
916076ba34aSJunchao Zhang }
917076ba34aSJunchao Zhang 
918076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
919076ba34aSJunchao Zhang 
920076ba34aSJunchao Zhang   Input Parameters:
921076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
922076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
923076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
924076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
925076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
926076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
927076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
928076ba34aSJunchao Zhang 
929*11a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
930076ba34aSJunchao Zhang */
9319371c9d4SSatish Balay static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm) {
932076ba34aSJunchao Zhang   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
933076ba34aSJunchao Zhang   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
934076ba34aSJunchao Zhang   Mat         C1, C2;               /* intermediate matrices */
935076ba34aSJunchao Zhang 
936076ba34aSJunchao Zhang   PetscFunctionBegin;
937076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
9389566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
9399566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
9409566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
941076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
9429566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
943dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
944076ba34aSJunchao Zhang 
9459566063dSJacob Faibussowitsch   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
946076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
947076ba34aSJunchao Zhang 
948076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
9499566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
9509566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
9519566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
952076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
9539566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
954dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
955076ba34aSJunchao Zhang 
9569566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
957076ba34aSJunchao Zhang 
958076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
959076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
960076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
961076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
962076ba34aSJunchao Zhang   mm->C1 = C1;
963076ba34aSJunchao Zhang   mm->C2 = C2;
964076ba34aSJunchao Zhang   PetscFunctionReturn(0);
965076ba34aSJunchao Zhang }
966076ba34aSJunchao Zhang 
9679371c9d4SSatish Balay PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) {
968076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
969076ba34aSJunchao Zhang   MatProductType               ptype;
970076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
971076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
972076ba34aSJunchao Zhang   MatMatStruct_AB             *ab;
973076ba34aSJunchao Zhang   MatMatStruct_AtB            *atb;
974076ba34aSJunchao Zhang   Mat                          A, B, Ad, Ao, Bd, Bo;
975076ba34aSJunchao Zhang   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
976076ba34aSJunchao Zhang 
977076ba34aSJunchao Zhang   PetscFunctionBegin;
978076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
979076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
980076ba34aSJunchao Zhang   ptype  = product->type;
981076ba34aSJunchao Zhang   A      = product->A;
982076ba34aSJunchao Zhang   B      = product->B;
983076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
984076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
985076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
986076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
987076ba34aSJunchao Zhang 
988076ba34aSJunchao Zhang   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
989076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
990076ba34aSJunchao Zhang     ab               = mmdata->mmAB;
991076ba34aSJunchao Zhang     atb              = mmdata->mmAtB;
992076ba34aSJunchao Zhang     if (ab) {
993076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
994076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
995076ba34aSJunchao Zhang     }
996076ba34aSJunchao Zhang     if (atb) {
997076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
998076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
999076ba34aSJunchao Zhang     }
1000076ba34aSJunchao Zhang     PetscFunctionReturn(0);
1001076ba34aSJunchao Zhang   }
1002076ba34aSJunchao Zhang 
1003076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1004076ba34aSJunchao Zhang     ab = mmdata->mmAB;
1005076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
100608401ef6SPierre Jolivet     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
10079566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
10085f80ce2aSJacob Faibussowitsch     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
10099566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
10109566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
10119371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1012076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
101308401ef6SPierre Jolivet     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
10149566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
10159566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1016076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1017076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1018076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(ab);
1019076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1020076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
102108401ef6SPierre Jolivet     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1022076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
10239566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
102408401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
10259566063dSJacob Faibussowitsch     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
10269566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1027076ba34aSJunchao Zhang 
1028076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
102908401ef6SPierre Jolivet     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
10309566063dSJacob Faibussowitsch     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
10319566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1032076ba34aSJunchao Zhang     /* Form C2_global */
10339371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1034076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1035076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1036076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1037076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1038076ba34aSJunchao Zhang     ab = mmdata->mmAB;
10399566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1040076ba34aSJunchao Zhang 
1041076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
104208401ef6SPierre Jolivet     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
10439566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
10449566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
10459371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1046076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
10479566063dSJacob Faibussowitsch     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
10489566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1049076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1050076ba34aSJunchao Zhang 
1051076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1052076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
105308401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
10549566063dSJacob Faibussowitsch     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
10559566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1056076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
10579566063dSJacob Faibussowitsch     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
10589566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
10599371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1060076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1061076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1062076ba34aSJunchao Zhang   }
1063076ba34aSJunchao Zhang   /* Split C_global to form C */
10649566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1065076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1066076ba34aSJunchao Zhang }
1067076ba34aSJunchao Zhang 
10689371c9d4SSatish Balay PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) {
1069076ba34aSJunchao Zhang   Mat                          A, B;
1070076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1071076ba34aSJunchao Zhang   MatProductType               ptype;
1072076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1073076ba34aSJunchao Zhang   MatMatStruct                *mm   = NULL;
1074076ba34aSJunchao Zhang   IS                           glob = NULL;
1075076ba34aSJunchao Zhang   const PetscInt              *garray;
1076076ba34aSJunchao Zhang   PetscInt                     m, n, M, N, sz;
1077076ba34aSJunchao Zhang   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1078076ba34aSJunchao Zhang 
1079076ba34aSJunchao Zhang   PetscFunctionBegin;
1080076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
108128b400f6SJacob Faibussowitsch   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1082076ba34aSJunchao Zhang   ptype = product->type;
1083076ba34aSJunchao Zhang   A     = product->A;
1084076ba34aSJunchao Zhang   B     = product->B;
1085076ba34aSJunchao Zhang 
1086076ba34aSJunchao Zhang   switch (ptype) {
10879371c9d4SSatish Balay   case MATPRODUCT_AB:
10889371c9d4SSatish Balay     m = A->rmap->n;
10899371c9d4SSatish Balay     n = B->cmap->n;
10909371c9d4SSatish Balay     M = A->rmap->N;
10919371c9d4SSatish Balay     N = B->cmap->N;
10929371c9d4SSatish Balay     break;
10939371c9d4SSatish Balay   case MATPRODUCT_AtB:
10949371c9d4SSatish Balay     m = A->cmap->n;
10959371c9d4SSatish Balay     n = B->cmap->n;
10969371c9d4SSatish Balay     M = A->cmap->N;
10979371c9d4SSatish Balay     N = B->cmap->N;
10989371c9d4SSatish Balay     break;
10999371c9d4SSatish Balay   case MATPRODUCT_PtAP:
11009371c9d4SSatish Balay     m = B->cmap->n;
11019371c9d4SSatish Balay     n = B->cmap->n;
11029371c9d4SSatish Balay     M = B->cmap->N;
11039371c9d4SSatish Balay     N = B->cmap->N;
11049371c9d4SSatish Balay     break; /* BtAB */
110598921bdaSJacob Faibussowitsch   default: SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1106076ba34aSJunchao Zhang   }
1107076ba34aSJunchao Zhang 
11089566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
11099566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
11109566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
11119566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1112076ba34aSJunchao Zhang 
1113076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1114076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1115076ba34aSJunchao Zhang 
1116076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1117076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
11189566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1119076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1120076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1121076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1122076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
11239566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
11249566063dSJacob Faibussowitsch     PetscCall(ISGetIndices(glob, &garray));
11259566063dSJacob Faibussowitsch     PetscCall(ISGetSize(glob, &sz));
1126076ba34aSJunchao Zhang     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
11279566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
11289566063dSJacob Faibussowitsch     PetscCall(ISRestoreIndices(glob, &garray));
11299566063dSJacob Faibussowitsch     PetscCall(ISDestroy(&glob));
1130076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1131076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1132076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1133076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1134076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1135076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
11369566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1137076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
11389566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
11399566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1140076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1141076ba34aSJunchao Zhang   }
1142076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
11439566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1144076ba34aSJunchao Zhang   C->product->data       = mmdata;
1145076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1146076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1147076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1148076ba34aSJunchao Zhang }
1149076ba34aSJunchao Zhang 
11509371c9d4SSatish Balay PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) {
1151076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1152076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1153076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1154076ba34aSJunchao Zhang 
1155076ba34aSJunchao Zhang   PetscFunctionBegin;
1156076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
115748a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1158076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1159076ba34aSJunchao Zhang     switch (product->type) {
1160076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1161076ba34aSJunchao Zhang       if (product->api_user) {
1162d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
11639566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1164d0609cedSBarry Smith         PetscOptionsEnd();
1165076ba34aSJunchao Zhang       } else {
1166d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
11679566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1168d0609cedSBarry Smith         PetscOptionsEnd();
1169076ba34aSJunchao Zhang       }
1170076ba34aSJunchao Zhang       break;
1171076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1172076ba34aSJunchao Zhang       if (product->api_user) {
1173d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
11749566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1175d0609cedSBarry Smith         PetscOptionsEnd();
1176076ba34aSJunchao Zhang       } else {
1177d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
11789566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1179d0609cedSBarry Smith         PetscOptionsEnd();
1180076ba34aSJunchao Zhang       }
1181076ba34aSJunchao Zhang       break;
1182076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1183076ba34aSJunchao Zhang       if (product->api_user) {
1184d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
11859566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1186d0609cedSBarry Smith         PetscOptionsEnd();
1187076ba34aSJunchao Zhang       } else {
1188d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
11899566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1190d0609cedSBarry Smith         PetscOptionsEnd();
1191076ba34aSJunchao Zhang       }
1192076ba34aSJunchao Zhang       break;
11939371c9d4SSatish Balay     default: break;
1194076ba34aSJunchao Zhang     }
1195076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1196076ba34aSJunchao Zhang   }
1197076ba34aSJunchao Zhang   if (match) {
1198076ba34aSJunchao Zhang     switch (product->type) {
1199076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1200076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
12019371c9d4SSatish Balay     case MATPRODUCT_PtAP: mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; break;
12029371c9d4SSatish Balay     default: break;
1203076ba34aSJunchao Zhang     }
1204076ba34aSJunchao Zhang   }
1205076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
120648a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1207076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1208076ba34aSJunchao Zhang }
1209076ba34aSJunchao Zhang 
12109371c9d4SSatish Balay static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) {
1211394ed5ebSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1212cbc6b225SStefano Zampini   Mat_MPIAIJKokkos *mpikok;
121342550becSJunchao Zhang 
121442550becSJunchao Zhang   PetscFunctionBegin;
12159566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
1216cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
12179566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
12189566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
12199566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
1220cbc6b225SStefano Zampini   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1221cbc6b225SStefano Zampini   delete mpikok;
1222394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
122342550becSJunchao Zhang   PetscFunctionReturn(0);
122442550becSJunchao Zhang }
122542550becSJunchao Zhang 
12269371c9d4SSatish Balay static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) {
1227394ed5ebSJunchao Zhang   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
122842550becSJunchao Zhang   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
122942550becSJunchao Zhang   Mat                         A = mpiaij->A, B = mpiaij->B;
1230158ec288SJunchao Zhang   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
123142550becSJunchao Zhang   MatScalarKokkosView         Aa, Ba;
1232394ed5ebSJunchao Zhang   MatScalarKokkosView         v1;
123342550becSJunchao Zhang   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
123442550becSJunchao Zhang   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1235158ec288SJunchao Zhang   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1236158ec288SJunchao Zhang   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1237394ed5ebSJunchao Zhang   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1238394ed5ebSJunchao Zhang   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
123942550becSJunchao Zhang   PetscMemType                memtype;
124042550becSJunchao Zhang 
124142550becSJunchao Zhang   PetscFunctionBegin;
12429566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
124342550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1244394ed5ebSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
124542550becSJunchao Zhang   } else {
1246394ed5ebSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
124742550becSJunchao Zhang   }
124842550becSJunchao Zhang 
124942550becSJunchao Zhang   if (imode == INSERT_VALUES) {
12509566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
12519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1252394ed5ebSJunchao Zhang   } else {
12539566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
12549566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
125542550becSJunchao Zhang   }
125642550becSJunchao Zhang 
125742550becSJunchao Zhang   /* Pack entries to be sent to remote */
12589371c9d4SSatish Balay   Kokkos::parallel_for(
12599371c9d4SSatish Balay     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
126042550becSJunchao Zhang 
126142550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
12629566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1263158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
12649371c9d4SSatish Balay   Kokkos::parallel_for(
12659371c9d4SSatish Balay     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1266158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1267158ec288SJunchao Zhang       if (i < Annz) {
1268158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1269ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1270158ec288SJunchao Zhang       } else {
1271158ec288SJunchao Zhang         i -= Annz;
1272158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1273ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1274158ec288SJunchao Zhang       }
1275158ec288SJunchao Zhang     });
12769566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
127742550becSJunchao Zhang 
1278158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
12799371c9d4SSatish Balay   Kokkos::parallel_for(
12809371c9d4SSatish Balay     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1281158ec288SJunchao Zhang       if (i < Annz2) {
1282158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1283158ec288SJunchao Zhang       } else {
1284158ec288SJunchao Zhang         i -= Annz2;
1285158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1286158ec288SJunchao Zhang       }
1287158ec288SJunchao Zhang     });
128842550becSJunchao Zhang 
1289394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
12909566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
12919566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1292394ed5ebSJunchao Zhang   } else {
12939566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
12949566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1295394ed5ebSJunchao Zhang   }
129642550becSJunchao Zhang   PetscFunctionReturn(0);
129742550becSJunchao Zhang }
129842550becSJunchao Zhang 
12999371c9d4SSatish Balay PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) {
130042550becSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1301076ba34aSJunchao Zhang 
1302076ba34aSJunchao Zhang   PetscFunctionBegin;
13039566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
13049566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
13059566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
13069566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
130742550becSJunchao Zhang   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
13089566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
1309076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1310076ba34aSJunchao Zhang }
1311076ba34aSJunchao Zhang 
13129371c9d4SSatish Balay PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) {
13138c3ff71bSJunchao Zhang   Mat         B;
1314076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
13158c3ff71bSJunchao Zhang 
13168c3ff71bSJunchao Zhang   PetscFunctionBegin;
13178c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
13189566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
13198c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
13209566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
13218c3ff71bSJunchao Zhang   }
13228c3ff71bSJunchao Zhang   B = *newmat;
13238c3ff71bSJunchao Zhang 
13246f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
13259566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
13269566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
13279566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
13288c3ff71bSJunchao Zhang 
1329076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
13309566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
13319566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
13329566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1333076ba34aSJunchao Zhang 
13348c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
13358c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
13368c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
13378c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1338076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1339076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
13408c3ff71bSJunchao Zhang 
13419566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
13429566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
13439566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
13449566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
13458c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13468c3ff71bSJunchao Zhang }
13473f3ba80aSJunchao Zhang /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
13498c3ff71bSJunchao Zhang 
   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
13513f3ba80aSJunchao Zhang 
13523f3ba80aSJunchao Zhang    Options Database Keys:
13533f3ba80aSJunchao Zhang .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
13543f3ba80aSJunchao Zhang 
13553f3ba80aSJunchao Zhang   Level: beginner
13563f3ba80aSJunchao Zhang 
1357*11a5261eSBarry Smith .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
13583f3ba80aSJunchao Zhang M*/
13599371c9d4SSatish Balay PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) {
13608c3ff71bSJunchao Zhang   PetscFunctionBegin;
13619566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
13629566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
13639566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
13648c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13658c3ff71bSJunchao Zhang }
13668c3ff71bSJunchao Zhang 
13678c3ff71bSJunchao Zhang /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.
13748c3ff71bSJunchao Zhang 
13758c3ff71bSJunchao Zhang    Collective
13768c3ff71bSJunchao Zhang 
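   Input Parameters:
+  comm - MPI communicator; a `MATSEQAIJKOKKOS` matrix is created if it has a single rank, otherwise a `MATMPIAIJKOKKOS` matrix
.  m - number of local rows (or `PETSC_DECIDE`)
.  n - number of local columns (or `PETSC_DECIDE`)
.  M - number of global rows (or `PETSC_DETERMINE`)
.  N - number of global columns (or `PETSC_DETERMINE`)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion of the local submatrix
         (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows);
         ignored on a single rank
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion of the local submatrix
         (possibly different for each row) or NULL; ignored on a single rank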
13848c3ff71bSJunchao Zhang 
13858c3ff71bSJunchao Zhang    Output Parameter:
13868c3ff71bSJunchao Zhang .  A - the matrix
13878c3ff71bSJunchao Zhang 
1388*11a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
13898c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1390*11a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
13918c3ff71bSJunchao Zhang 
13928c3ff71bSJunchao Zhang    Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored
13948c3ff71bSJunchao Zhang 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
13968c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
13978c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
13988c3ff71bSJunchao Zhang 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
14018c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
14028c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
14038c3ff71bSJunchao Zhang 
14048c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
14058c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
14068c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
14078c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
14088c3ff71bSJunchao Zhang 
14098c3ff71bSJunchao Zhang    Level: intermediate
14108c3ff71bSJunchao Zhang 
.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
14128c3ff71bSJunchao Zhang @*/
14139371c9d4SSatish Balay PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) {
14148c3ff71bSJunchao Zhang   PetscMPIInt size;
14158c3ff71bSJunchao Zhang 
14168c3ff71bSJunchao Zhang   PetscFunctionBegin;
14179566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14189566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
14199566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
14208c3ff71bSJunchao Zhang   if (size > 1) {
14219566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
14229566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
14238c3ff71bSJunchao Zhang   } else {
14249566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
14259566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
14268c3ff71bSJunchao Zhang   }
14278c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
14288c3ff71bSJunchao Zhang }
14298c3ff71bSJunchao Zhang 
1430a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
14319371c9d4SSatish Balay PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B) {
1432a587d139SMark   PetscMPIInt                size, rank;
1433a587d139SMark   MPI_Comm                   comm;
1434042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat = NULL;
1435a587d139SMark 
1436a587d139SMark   PetscFunctionBegin;
14379566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
14389566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
14399566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1440a587d139SMark   if (size == 1) {
14419566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
14429566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1443a587d139SMark   } else {
1444a587d139SMark     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
14459566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
14469566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
14479566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
14482c71b3e2SJacob Faibussowitsch     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1449a587d139SMark   }
1450a587d139SMark   // act like MatSetValues because not called on host
1451a587d139SMark   if (A->assembled) {
145248a46eb9SPierre Jolivet     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1453a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1454a587d139SMark   } else {
14559566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1456a587d139SMark   }
1457a587d139SMark   if (!d_mat) {
1458042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1459a587d139SMark     Mat_SeqAIJKokkos     *aijkokA;
1460a587d139SMark     Mat_SeqAIJ           *jaca;
1461a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1462a587d139SMark     Mat                   Amat;
1463042217e8SBarry Smith     PetscInt             *colmap;
1464042217e8SBarry Smith 
1465042217e8SBarry Smith     /* create and copy h_mat */
146649b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
14679566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1468a587d139SMark     if (size == 1) {
1469a587d139SMark       Amat            = A;
1470a587d139SMark       jaca            = (Mat_SeqAIJ *)A->data;
14719371c9d4SSatish Balay       h_mat.rstart    = 0;
14729371c9d4SSatish Balay       h_mat.rend      = A->rmap->n;
14739371c9d4SSatish Balay       h_mat.cstart    = 0;
14749371c9d4SSatish Balay       h_mat.cend      = A->cmap->n;
1475a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1476a587d139SMark       h_mat.offdiag.a                   = NULL;
1477a587d139SMark       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1478a587d139SMark     } else {
1479a587d139SMark       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1480a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1481a587d139SMark       PetscInt          ii;
1482a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1483042217e8SBarry Smith 
1484a587d139SMark       Amat    = aij->A;
1485a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1486a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1487a587d139SMark       jaca    = (Mat_SeqAIJ *)aij->A->data;
148808401ef6SPierre Jolivet       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
148908401ef6SPierre Jolivet       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1490a587d139SMark       aij->donotstash          = PETSC_TRUE;
1491a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1492a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
14939566063dSJacob Faibussowitsch       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
14949566063dSJacob Faibussowitsch       PetscCall(PetscLogObjectMemory((PetscObject)A, (A->cmap->N) * sizeof(PetscInt)));
1495042217e8SBarry Smith       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1496a587d139SMark       // allocate B copy data
14979371c9d4SSatish Balay       h_mat.rstart = A->rmap->rstart;
14989371c9d4SSatish Balay       h_mat.rend   = A->rmap->rend;
14999371c9d4SSatish Balay       h_mat.cstart = A->cmap->rstart;
15009371c9d4SSatish Balay       h_mat.cend   = A->cmap->rend;
1501a587d139SMark       nnz          = jacb->i[n];
1502a587d139SMark       if (jacb->compressedrow.use) {
1503a587d139SMark         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1504300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1505300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1506300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1507a587d139SMark       } else {
150899551766SMark Adams         h_mat.offdiag.i = aijkokB->i_device_data();
1509a587d139SMark       }
151099551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1511076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1512a587d139SMark       {
1513042217e8SBarry Smith         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1514300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1515300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1516300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
15179566063dSJacob Faibussowitsch         PetscCall(PetscFree(colmap));
1518a587d139SMark       }
1519a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1520a587d139SMark       h_mat.offdiag.n                 = n;
1521a587d139SMark     }
1522a587d139SMark     // allocate A copy data
1523a587d139SMark     nnz                          = jaca->i[n];
1524a587d139SMark     h_mat.diag.n                 = n;
1525a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
15269566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1527aed4548fSBarry Smith     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not suppport compressed row (todo)");
152899551766SMark Adams     h_mat.diag.i = aijkokA->i_device_data();
152999551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1530076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1531a587d139SMark     // copy pointers and metdata to device
15329566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
15339566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
15349566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1535a587d139SMark   }
1536a587d139SMark   *B           = d_mat;       // return it, set it in Mat, and set it up
1537a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1538a587d139SMark   PetscFunctionReturn(0);
1539a587d139SMark }
1540076ba34aSJunchao Zhang 
15419371c9d4SSatish Balay PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask) {
1542076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1543076ba34aSJunchao Zhang 
1544076ba34aSJunchao Zhang   PetscFunctionBegin;
1545076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1546076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1547076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1548076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
1549076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1550076ba34aSJunchao Zhang }
1551076ba34aSJunchao Zhang 
15529371c9d4SSatish Balay PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A) {
1553076ba34aSJunchao Zhang   PetscMPIInt size;
1554076ba34aSJunchao Zhang   Mat         Ad, Ao;
1555076ba34aSJunchao Zhang   const char *amask, *bmask;
1556076ba34aSJunchao Zhang 
1557076ba34aSJunchao Zhang   PetscFunctionBegin;
15589566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1559076ba34aSJunchao Zhang 
1560076ba34aSJunchao Zhang   if (size == 1) {
15619566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
15629566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1563076ba34aSJunchao Zhang   } else {
1564076ba34aSJunchao Zhang     Ad = ((Mat_MPIAIJ *)A->data)->A;
1565076ba34aSJunchao Zhang     Ao = ((Mat_MPIAIJ *)A->data)->B;
15669566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
15679566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
15689566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1569076ba34aSJunchao Zhang   }
1570076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1571076ba34aSJunchao Zhang }
1572