xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision d71ae5a4db6382e7f06317b8d368875286fe9008)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscsf.h>
342550becSJunchao Zhang #include <petsc/private/sfimpl.h>
48c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
542550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
711d22bbfSJunchao Zhang 
8*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
9*d71ae5a4SJacob Faibussowitsch {
105519a089SJose E. Roman   Mat_SeqAIJKokkos *aijkok;
118c3ff71bSJunchao Zhang 
128c3ff71bSJunchao Zhang   PetscFunctionBegin;
139566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
145519a089SJose E. Roman   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
15a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
16a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
17a587d139SMark   }
18a587d139SMark 
198c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
208c3ff71bSJunchao Zhang }
218c3ff71bSJunchao Zhang 
22*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
23*d71ae5a4SJacob Faibussowitsch {
248c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
258c3ff71bSJunchao Zhang 
268c3ff71bSJunchao Zhang   PetscFunctionBegin;
279566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
289566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
296a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
308c3ff71bSJunchao Zhang   if (d_nnz) {
316a29ce69SStefano Zampini     PetscInt i;
32ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
338c3ff71bSJunchao Zhang   }
348c3ff71bSJunchao Zhang   if (o_nnz) {
356a29ce69SStefano Zampini     PetscInt i;
36ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
378c3ff71bSJunchao Zhang   }
386a29ce69SStefano Zampini #endif
396a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
409566063dSJacob Faibussowitsch   PetscCall(PetscTableDestroy(&mpiaij->colmap));
416a29ce69SStefano Zampini #else
429566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
436a29ce69SStefano Zampini #endif
449566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
459566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
469566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
476a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
489566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
496a29ce69SStefano Zampini 
506a29ce69SStefano Zampini   if (!mpiaij->A) {
519566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
529566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
536a29ce69SStefano Zampini   }
546a29ce69SStefano Zampini   if (!mpiaij->B) {
556a29ce69SStefano Zampini     PetscMPIInt size;
569566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
579566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
589566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
598c3ff71bSJunchao Zhang   }
609566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
619566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
629566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
639566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
648c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
658c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
668c3ff71bSJunchao Zhang }
678c3ff71bSJunchao Zhang 
68*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
69*d71ae5a4SJacob Faibussowitsch {
708c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
718c3ff71bSJunchao Zhang   PetscInt    nt;
728c3ff71bSJunchao Zhang 
738c3ff71bSJunchao Zhang   PetscFunctionBegin;
749566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
7508401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
769566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
779566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
789566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
799566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
808c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
818c3ff71bSJunchao Zhang }
828c3ff71bSJunchao Zhang 
83*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
84*d71ae5a4SJacob Faibussowitsch {
858c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
868c3ff71bSJunchao Zhang   PetscInt    nt;
878c3ff71bSJunchao Zhang 
888c3ff71bSJunchao Zhang   PetscFunctionBegin;
899566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
9008401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
919566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
929566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
939566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
949566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
958c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
968c3ff71bSJunchao Zhang }
978c3ff71bSJunchao Zhang 
98*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
99*d71ae5a4SJacob Faibussowitsch {
1008c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1018c3ff71bSJunchao Zhang   PetscInt    nt;
1028c3ff71bSJunchao Zhang 
1038c3ff71bSJunchao Zhang   PetscFunctionBegin;
1049566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
10508401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1069566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1079566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1089566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1099566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1108c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
1118c3ff71bSJunchao Zhang }
1128c3ff71bSJunchao Zhang 
113076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
114076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
115076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
116076ba34aSJunchao Zhang */
117*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
118*d71ae5a4SJacob Faibussowitsch {
119076ba34aSJunchao Zhang   Mat             Ad, Ao;
120076ba34aSJunchao Zhang   const PetscInt *cmap;
121076ba34aSJunchao Zhang 
122076ba34aSJunchao Zhang   PetscFunctionBegin;
1239566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1249566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
125076ba34aSJunchao Zhang   if (glob) {
126076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1279566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1289566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1299566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1309566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
131076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
132076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1339566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
134076ba34aSJunchao Zhang   }
135076ba34aSJunchao Zhang   PetscFunctionReturn(0);
136076ba34aSJunchao Zhang }
137076ba34aSJunchao Zhang 
138076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
139076ba34aSJunchao Zhang struct MatMatStruct {
140076ba34aSJunchao Zhang   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
141076ba34aSJunchao Zhang   PetscSF             sf;      /* SF to send/recv matrix entries */
142076ba34aSJunchao Zhang   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
143076ba34aSJunchao Zhang   Mat                 C1, C2, B_local;
144076ba34aSJunchao Zhang   KokkosCsrMatrix     C1_global, C2_global, C_global;
145076ba34aSJunchao Zhang   KernelHandle        kh;
146*d71ae5a4SJacob Faibussowitsch   MatMatStruct()
147*d71ae5a4SJacob Faibussowitsch   {
148076ba34aSJunchao Zhang     C1 = C2 = B_local = NULL;
149076ba34aSJunchao Zhang     sf                = NULL;
150076ba34aSJunchao Zhang   }
151076ba34aSJunchao Zhang 
152*d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
153*d71ae5a4SJacob Faibussowitsch   {
154076ba34aSJunchao Zhang     MatDestroy(&C1);
155076ba34aSJunchao Zhang     MatDestroy(&C2);
156076ba34aSJunchao Zhang     MatDestroy(&B_local);
157076ba34aSJunchao Zhang     PetscSFDestroy(&sf);
158076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
159076ba34aSJunchao Zhang   }
160076ba34aSJunchao Zhang };
161076ba34aSJunchao Zhang 
162076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
163076ba34aSJunchao Zhang   MatColIdxKokkosView rows;
164076ba34aSJunchao Zhang   MatRowMapKokkosView rowoffset;
165076ba34aSJunchao Zhang   Mat                 B_other, C_petsc; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
166076ba34aSJunchao Zhang 
167076ba34aSJunchao Zhang   MatMatStruct_AB() : B_other(NULL), C_petsc(NULL) { }
168*d71ae5a4SJacob Faibussowitsch   ~MatMatStruct_AB()
169*d71ae5a4SJacob Faibussowitsch   {
170076ba34aSJunchao Zhang     MatDestroy(&B_other);
171076ba34aSJunchao Zhang     MatDestroy(&C_petsc);
172076ba34aSJunchao Zhang   }
173076ba34aSJunchao Zhang };
174076ba34aSJunchao Zhang 
175076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
176076ba34aSJunchao Zhang   MatRowMapKokkosView srcrowoffset, dstrowoffset;
177076ba34aSJunchao Zhang };
178076ba34aSJunchao Zhang 
1799371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
180076ba34aSJunchao Zhang   MatMatStruct_AB  *mmAB;
181076ba34aSJunchao Zhang   MatMatStruct_AtB *mmAtB;
182076ba34aSJunchao Zhang   PetscBool         reusesym;
183076ba34aSJunchao Zhang 
184076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos() : mmAB(NULL), mmAtB(NULL), reusesym(PETSC_FALSE) { }
185*d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
186*d71ae5a4SJacob Faibussowitsch   {
187076ba34aSJunchao Zhang     delete mmAB;
188076ba34aSJunchao Zhang     delete mmAtB;
189076ba34aSJunchao Zhang   }
190076ba34aSJunchao Zhang };
191076ba34aSJunchao Zhang 
192*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
193*d71ae5a4SJacob Faibussowitsch {
194076ba34aSJunchao Zhang   PetscFunctionBegin;
1959566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
196076ba34aSJunchao Zhang   PetscFunctionReturn(0);
197076ba34aSJunchao Zhang }
198076ba34aSJunchao Zhang 
199076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
200076ba34aSJunchao Zhang 
201076ba34aSJunchao Zhang    Input Parameters:
202076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
203076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
204076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
205076ba34aSJunchao Zhang 
206076ba34aSJunchao Zhang    Output Parameters:
207076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
208076ba34aSJunchao Zhang  */
209*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
210*d71ae5a4SJacob Faibussowitsch {
211076ba34aSJunchao Zhang   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
212076ba34aSJunchao Zhang   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
213076ba34aSJunchao Zhang 
214076ba34aSJunchao Zhang   PetscFunctionBegin;
2159371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
2169371c9d4SSatish Balay     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
2179566063dSJacob Faibussowitsch   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
218076ba34aSJunchao Zhang   PetscFunctionReturn(0);
219076ba34aSJunchao Zhang }
220076ba34aSJunchao Zhang 
221076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
222076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
223076ba34aSJunchao Zhang 
224076ba34aSJunchao Zhang   Input Parameters:
225076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
226076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
227076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
228076ba34aSJunchao Zhang 
229076ba34aSJunchao Zhang   Output Parameters:
230076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
231076ba34aSJunchao Zhang */
232*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
233*d71ae5a4SJacob Faibussowitsch {
234076ba34aSJunchao Zhang   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
235076ba34aSJunchao Zhang   PetscInt          m, n, M, N, Am, An, Bm, Bn;
236076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
237076ba34aSJunchao Zhang 
238076ba34aSJunchao Zhang   PetscFunctionBegin;
2399566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2409566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2419566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2429566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
243076ba34aSJunchao Zhang 
244aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
24508401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
24608401ef6SPierre Jolivet   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
24708401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
248076ba34aSJunchao Zhang   mpiaij->A = A;
249076ba34aSJunchao Zhang   mpiaij->B = B;
250076ba34aSJunchao Zhang 
251076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
252076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
253076ba34aSJunchao Zhang 
2549566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2559566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
256076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
257076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
258076ba34aSJunchao Zhang   */
2599566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2609566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2619566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
262076ba34aSJunchao Zhang 
263076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
264076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
265076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
266076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
267076ba34aSJunchao Zhang   PetscFunctionReturn(0);
268076ba34aSJunchao Zhang }
269076ba34aSJunchao Zhang 
270076ba34aSJunchao Zhang /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrice (B) to form a SEQAIJKOKKOS matrix (C).
271076ba34aSJunchao Zhang 
272076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
273076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
274076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
275076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
276076ba34aSJunchao Zhang 
277076ba34aSJunchao Zhang    Collective on comm of ownerSF
278076ba34aSJunchao Zhang 
279076ba34aSJunchao Zhang    Input Parameters:
280076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
281076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
282076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
283076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
284076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
285076ba34aSJunchao Zhang 
286076ba34aSJunchao Zhang    Input/Output Parameters (out when resue = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
287076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
288076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
289076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
290076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
291076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
292076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
293076ba34aSJunchao Zhang */
294*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
295*d71ae5a4SJacob Faibussowitsch {
296076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok, *ckok;
297076ba34aSJunchao Zhang 
298076ba34aSJunchao Zhang   PetscFunctionBegin;
2999566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
300076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
301076ba34aSJunchao Zhang 
302076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
303076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
304076ba34aSJunchao Zhang 
305076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
306076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
307076ba34aSJunchao Zhang     const auto &Ca = ckok->a_dual.view_device();
308076ba34aSJunchao Zhang 
309076ba34aSJunchao Zhang     /* Copy Ba to abuf */
3109371c9d4SSatish Balay     Kokkos::parallel_for(
3119371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
312076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
313076ba34aSJunchao Zhang         PetscInt r    = rows(i);
314076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
3159371c9d4SSatish Balay         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
316076ba34aSJunchao Zhang       });
317076ba34aSJunchao Zhang 
318076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
3199566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
3209566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
321076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
322076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
323076ba34aSJunchao Zhang     MPI_Comm    comm;
324076ba34aSJunchao Zhang     PetscMPIInt tag;
325076ba34aSJunchao Zhang     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
326076ba34aSJunchao Zhang 
3279566063dSJacob Faibussowitsch     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
3289566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
329076ba34aSJunchao Zhang     Cm = nleaves; /* row size of C */
330076ba34aSJunchao Zhang     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
331076ba34aSJunchao Zhang 
332076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
333076ba34aSJunchao Zhang     PetscInt       *Browlens;
334076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
3359566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nroots, &Browlens));
336076ba34aSJunchao Zhang     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
337076ba34aSJunchao Zhang 
338076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
339076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
340076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
341076ba34aSJunchao Zhang     Ci_h[0] = 0;
3429566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
3439566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
344076ba34aSJunchao Zhang     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
345076ba34aSJunchao Zhang     Cnnz = Ci_h[Cm];
346076ba34aSJunchao Zhang     Ci.modify_host();
347076ba34aSJunchao Zhang     Ci.sync_device();
348076ba34aSJunchao Zhang 
349076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
350076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cj("j", Cnnz);
351076ba34aSJunchao Zhang     MatScalarKokkosDualView Ca("a", Cnnz);
352076ba34aSJunchao Zhang 
353076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
354076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
355076ba34aSJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
356076ba34aSJunchao Zhang     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
357076ba34aSJunchao Zhang     MPI_Request       *reqs;
358076ba34aSJunchao Zhang 
3599566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
3609566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
361076ba34aSJunchao Zhang 
362076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
363076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
364076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
365076ba34aSJunchao Zhang 
366076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
367076ba34aSJunchao Zhang     */
3689566063dSJacob Faibussowitsch     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
369076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
370076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
371076ba34aSJunchao Zhang 
372076ba34aSJunchao Zhang     sdisp[0]  = 0;
373076ba34aSJunchao Zhang     rowptr[0] = 0;
374076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) { /* for each receiver */
375076ba34aSJunchao Zhang       PetscInt len, nz = 0;
376076ba34aSJunchao Zhang       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
377076ba34aSJunchao Zhang         len           = Browlens[irootloc[j]];
378076ba34aSJunchao Zhang         rowptr[j + 1] = rowptr[j] + len;
379076ba34aSJunchao Zhang         nz += len;
380076ba34aSJunchao Zhang       }
381076ba34aSJunchao Zhang       sdisp[i + 1] = sdisp[i] + nz;
382076ba34aSJunchao Zhang     }
3839566063dSJacob Faibussowitsch     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
3849566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
3859566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
3869566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
387076ba34aSJunchao Zhang 
388076ba34aSJunchao Zhang     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
389076ba34aSJunchao Zhang     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
390076ba34aSJunchao Zhang     PetscSFNode *iremote;
3919566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote));
392076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) { /* for each sender */
393076ba34aSJunchao Zhang       k = 0;
394076ba34aSJunchao Zhang       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
395076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
396076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
397076ba34aSJunchao Zhang         k++;
398076ba34aSJunchao Zhang       }
399076ba34aSJunchao Zhang     }
400076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
4019566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &bcastSF));
4029566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
403076ba34aSJunchao Zhang 
404076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
405076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
406076ba34aSJunchao Zhang     */
407076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
408076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
409076ba34aSJunchao Zhang     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
410076ba34aSJunchao Zhang 
411076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
412076ba34aSJunchao Zhang 
413076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
414076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
415076ba34aSJunchao Zhang 
416076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
417076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
418076ba34aSJunchao Zhang     const auto &Bj = bkok->j_dual.view_device();
419076ba34aSJunchao Zhang 
420076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
4219371c9d4SSatish Balay     Kokkos::parallel_for(
4229371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
423076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
424076ba34aSJunchao Zhang         PetscInt r    = rows(i);
425076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
426076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
427076ba34aSJunchao Zhang           abuf(base + k) = Ba(Bi(r) + k);
428076ba34aSJunchao Zhang           jbuf(base + k) = l2g(Bj(Bi(r) + k));
429076ba34aSJunchao Zhang         });
430076ba34aSJunchao Zhang       });
431076ba34aSJunchao Zhang 
432076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
4339566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4349566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
4359566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4369566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
437076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
438076ba34aSJunchao Zhang     Cj.sync_host();
439076ba34aSJunchao Zhang     Ca.modify_device();
440076ba34aSJunchao Zhang 
441076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
442076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
4439566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
4449566063dSJacob Faibussowitsch     PetscCall(PetscFree3(sdisp, rdisp, reqs));
4459566063dSJacob Faibussowitsch     PetscCall(PetscFree(Browlens));
44698921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
447076ba34aSJunchao Zhang   PetscFunctionReturn(0);
448076ba34aSJunchao Zhang }
449076ba34aSJunchao Zhang 
450076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
451076ba34aSJunchao Zhang 
452076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
453076ba34aSJunchao Zhang 
454076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
455076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
456076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
457076ba34aSJunchao Zhang 
458076ba34aSJunchao Zhang   Input Parameters:
459076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
460076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
461076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
462076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
463076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
464076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
465076ba34aSJunchao Zhang 
466076ba34aSJunchao Zhang   Output Parameters:
467076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
468076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
469076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
470076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
471076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
472076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
473076ba34aSJunchao Zhang 
474076ba34aSJunchao Zhang    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
475076ba34aSJunchao Zhang  */
476*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
477*d71ae5a4SJacob Faibussowitsch {
478076ba34aSJunchao Zhang   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
479076ba34aSJunchao Zhang   const PetscInt   *Ai;
480076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok;
481076ba34aSJunchao Zhang 
482076ba34aSJunchao Zhang   PetscFunctionBegin;
4839566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
4849566063dSJacob Faibussowitsch   PetscCall(MatGetSize(A, &Am, &An));
485076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
486076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
487076ba34aSJunchao Zhang   Annz = Ai[Am];
488076ba34aSJunchao Zhang 
489076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
490076ba34aSJunchao Zhang     /* Send Aa to abuf */
4919566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
4929566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
493076ba34aSJunchao Zhang 
494076ba34aSJunchao Zhang     /* Copy abuf to Ca */
495076ba34aSJunchao Zhang     const MatScalarKokkosView &Ca = C.values;
496076ba34aSJunchao Zhang     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
4979371c9d4SSatish Balay     Kokkos::parallel_for(
4989371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
499076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
500076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
501076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
502076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
503076ba34aSJunchao Zhang       });
504076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
505076ba34aSJunchao Zhang     MPI_Comm     comm;
506076ba34aSJunchao Zhang     MPI_Request *reqs;
507076ba34aSJunchao Zhang     PetscMPIInt  tag;
508076ba34aSJunchao Zhang     PetscInt     Cm;
509076ba34aSJunchao Zhang 
5109566063dSJacob Faibussowitsch     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
5119566063dSJacob Faibussowitsch     PetscCall(PetscCommGetNewTag(comm, &tag));
512076ba34aSJunchao Zhang 
513076ba34aSJunchao Zhang     PetscInt           niranks, nranks, nroots, nleaves;
514076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
515076ba34aSJunchao Zhang     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
5169566063dSJacob Faibussowitsch     PetscCall(PetscSFSetUp(ownerSF));
5179566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
5189566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
5199566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
52008401ef6SPierre Jolivet     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
521076ba34aSJunchao Zhang     Cm    = nroots;
522076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
523076ba34aSJunchao Zhang 
524076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
525076ba34aSJunchao Zhang     PetscInt               *srowlens;                              /* send buf of row lens */
526076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
527076ba34aSJunchao Zhang     PetscInt               *rrowlens = rrowlens_h.data();
528076ba34aSJunchao Zhang 
5299566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
530076ba34aSJunchao Zhang     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
531076ba34aSJunchao Zhang     rrowlens[0] = 0;
532076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
5339566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
5349566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
5359566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
536076ba34aSJunchao Zhang 
537076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
538076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
539076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
540076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
541076ba34aSJunchao Zhang 
542076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {
543076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
544076ba34aSJunchao Zhang #if defined(PETSC_USE_DEBUG)
545aed4548fSBarry Smith       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
546076ba34aSJunchao Zhang #endif
547076ba34aSJunchao Zhang       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
548076ba34aSJunchao Zhang     }
549076ba34aSJunchao Zhang     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
550076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
551076ba34aSJunchao Zhang 
552076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
553076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
554076ba34aSJunchao Zhang     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
555076ba34aSJunchao Zhang     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
556076ba34aSJunchao Zhang 
5579566063dSJacob Faibussowitsch     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
558076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
559076ba34aSJunchao Zhang       r                    = rows[i];                   /* row id in C */
560076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
561076ba34aSJunchao Zhang       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
562076ba34aSJunchao Zhang     }
5639566063dSJacob Faibussowitsch     PetscCall(PetscFree(currowlens));
564076ba34aSJunchao Zhang 
565076ba34aSJunchao Zhang     rrowlens--;
566076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
567076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
568076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
569076ba34aSJunchao Zhang 
570076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
571076ba34aSJunchao Zhang     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
5729566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
573076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
5749566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
5759566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
5769566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
577076ba34aSJunchao Zhang 
578076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
579076ba34aSJunchao Zhang     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
580076ba34aSJunchao Zhang     PetscSFNode *iremote;
5819566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
582076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) {
583076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
584076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
585076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
586076ba34aSJunchao Zhang       for (PetscInt k = 0; k < nz; k++) {
587076ba34aSJunchao Zhang         iremote[leafbase + k].rank  = ranks[i];
588076ba34aSJunchao Zhang         iremote[leafbase + k].index = rootbase + k;
589076ba34aSJunchao Zhang       }
590076ba34aSJunchao Zhang     }
5919566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &reduceSF));
5929566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
5939566063dSJacob Faibussowitsch     PetscCall(PetscFree2(sdisp, rdisp));
594076ba34aSJunchao Zhang 
595076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
596076ba34aSJunchao Zhang 
597076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
598076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
599076ba34aSJunchao Zhang     if (local) {
600076ba34aSJunchao Zhang       Ajg                           = MatColIdxKokkosView("j", Annz);
601076ba34aSJunchao Zhang       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
6029371c9d4SSatish Balay       Kokkos::parallel_for(
6039371c9d4SSatish Balay         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
604076ba34aSJunchao Zhang     } else {
605076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
606076ba34aSJunchao Zhang     }
607076ba34aSJunchao Zhang 
608076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", Cnnz);
609076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", Cnnz);
6109566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
6119566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
6129566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
6139566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
614076ba34aSJunchao Zhang 
615076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
616076ba34aSJunchao Zhang     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
617076ba34aSJunchao Zhang     MatColIdxKokkosView Cj("j", Cnnz);
618076ba34aSJunchao Zhang     MatScalarKokkosView Ca("a", Cnnz);
619076ba34aSJunchao Zhang 
6209371c9d4SSatish Balay     Kokkos::parallel_for(
6219371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
622076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
623076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
624076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
625076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
626076ba34aSJunchao Zhang           Ca(dst + k) = abuf(src + k);
627076ba34aSJunchao Zhang           Cj(dst + k) = jbuf(src + k);
628076ba34aSJunchao Zhang         });
629076ba34aSJunchao Zhang       });
630076ba34aSJunchao Zhang 
631076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
632076ba34aSJunchao Zhang     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
6339566063dSJacob Faibussowitsch     PetscCall(PetscFree2(srowlens, reqs));
63498921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
635076ba34aSJunchao Zhang   PetscFunctionReturn(0);
636076ba34aSJunchao Zhang }
637076ba34aSJunchao Zhang 
63811a5261eSBarry Smith /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
639076ba34aSJunchao Zhang 
640076ba34aSJunchao Zhang   Input Parameters:
64111a5261eSBarry Smith +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
642076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
643076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
64411a5261eSBarry Smith -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
645076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
646076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
647076ba34aSJunchao Zhang 
648076ba34aSJunchao Zhang   Output Parameters:
64911a5261eSBarry Smith +  C        - the updated `MATMPIAIJKOKKOS` matrix
65011a5261eSBarry Smith -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
651076ba34aSJunchao Zhang 
65211a5261eSBarry Smith   Note:
65311a5261eSBarry Smith    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
65411a5261eSBarry Smith 
65511a5261eSBarry Smith .seealso: `MATMPIAIJKOKKOS`
656076ba34aSJunchao Zhang  */
657*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
658*d71ae5a4SJacob Faibussowitsch {
659076ba34aSJunchao Zhang   const MatScalarKokkosView      &Ca = csrmat.values;
660076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
661076ba34aSJunchao Zhang   PetscInt                        m, n, N;
662076ba34aSJunchao Zhang 
663076ba34aSJunchao Zhang   PetscFunctionBegin;
6649566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(C, &m, &n));
6659566063dSJacob Faibussowitsch   PetscCall(MatGetSize(C, NULL, &N));
666076ba34aSJunchao Zhang 
667076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
668076ba34aSJunchao Zhang     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
669076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
670076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
671076ba34aSJunchao Zhang     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
672076ba34aSJunchao Zhang     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
673076ba34aSJunchao Zhang 
674076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
6759371c9d4SSatish Balay     Kokkos::parallel_for(
6769371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
677076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
678076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
679076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
680076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
681076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
682076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
683076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
684076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
685076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
686076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
687076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
688076ba34aSJunchao Zhang           } else { /* k in [cdend, clen) */
689076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
690076ba34aSJunchao Zhang           }
691076ba34aSJunchao Zhang         });
692076ba34aSJunchao Zhang       });
693076ba34aSJunchao Zhang 
694076ba34aSJunchao Zhang     akok->a_dual.modify_device();
695076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
696076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
697076ba34aSJunchao Zhang     Mat                        Cd, Co;
698076ba34aSJunchao Zhang     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
699076ba34aSJunchao Zhang     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
700076ba34aSJunchao Zhang     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
701076ba34aSJunchao Zhang     PetscInt                   cstart, cend;
702076ba34aSJunchao Zhang 
703076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
704076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
705076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
706076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
707076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
708076ba34aSJunchao Zhang      */
709076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart", m);
7109566063dSJacob Faibussowitsch     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
711076ba34aSJunchao Zhang 
712076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
713076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
714076ba34aSJunchao Zhang      */
7159371c9d4SSatish Balay     Kokkos::parallel_for(
7169371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
717076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
718076ba34aSJunchao Zhang                                                    PetscInt i = t.league_rank(); /* row i */
719076ba34aSJunchao Zhang                                                    PetscInt j, first, count, step;
720076ba34aSJunchao Zhang 
721076ba34aSJunchao Zhang                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
722076ba34aSJunchao Zhang                                                      Cdi(0) = 0;
723076ba34aSJunchao Zhang                                                      Coi(0) = 0;
724076ba34aSJunchao Zhang                                                    }
725076ba34aSJunchao Zhang 
726076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
727076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
728076ba34aSJunchao Zhang         */
729076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - Ci(i);
730076ba34aSJunchao Zhang                                                    first = Ci(i);
731076ba34aSJunchao Zhang                                                    while (count > 0) {
732076ba34aSJunchao Zhang                                                      j    = first;
733076ba34aSJunchao Zhang                                                      step = count / 2;
734076ba34aSJunchao Zhang                                                      j += step;
735076ba34aSJunchao Zhang                                                      if (Cj(j) < cstart) {
736076ba34aSJunchao Zhang                                                        first = ++j;
737076ba34aSJunchao Zhang                                                        count -= step + 1;
738076ba34aSJunchao Zhang                                                      } else count = step;
739076ba34aSJunchao Zhang                                                    }
740076ba34aSJunchao Zhang                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
741076ba34aSJunchao Zhang 
742076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
743076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - first;
744076ba34aSJunchao Zhang                                                    while (count > 0) {
745076ba34aSJunchao Zhang                                                      j    = first;
746076ba34aSJunchao Zhang                                                      step = count / 2;
747076ba34aSJunchao Zhang                                                      j += step;
748076ba34aSJunchao Zhang                                                      if (Cj(j) < cend) {
749076ba34aSJunchao Zhang                                                        first = ++j;
750076ba34aSJunchao Zhang                                                        count -= step + 1;
751076ba34aSJunchao Zhang                                                      } else count = step;
752076ba34aSJunchao Zhang                                                    }
753076ba34aSJunchao Zhang                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
754076ba34aSJunchao Zhang                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
755076ba34aSJunchao Zhang         });
756076ba34aSJunchao Zhang       });
757076ba34aSJunchao Zhang 
758076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
7599371c9d4SSatish Balay     Kokkos::parallel_scan(
7609371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
761076ba34aSJunchao Zhang         update += Cdi(i);
762076ba34aSJunchao Zhang         if (final) Cdi(i) = update;
763076ba34aSJunchao Zhang       });
7649371c9d4SSatish Balay     Kokkos::parallel_scan(
7659371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
766076ba34aSJunchao Zhang         update += Coi(i);
767076ba34aSJunchao Zhang         if (final) Coi(i) = update;
768076ba34aSJunchao Zhang       });
769076ba34aSJunchao Zhang 
770076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
771076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
772076ba34aSJunchao Zhang     */
773076ba34aSJunchao Zhang     Cdi_dual.modify_device();
774076ba34aSJunchao Zhang     Coi_dual.modify_device();
775076ba34aSJunchao Zhang     Cdi_dual.sync_host();
776076ba34aSJunchao Zhang     Coi_dual.sync_host();
777076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
778076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
779076ba34aSJunchao Zhang 
780076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
781076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
782076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
783076ba34aSJunchao Zhang 
784076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
785076ba34aSJunchao Zhang     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
786076ba34aSJunchao Zhang     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
787076ba34aSJunchao Zhang 
7889371c9d4SSatish Balay     Kokkos::parallel_for(
7899371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
790076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
791076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
792076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
793076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
794076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
795076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
796076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
797076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
798076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
799076ba34aSJunchao Zhang             Coj(Coi(i) + k) = Cj(Ci(i) + k);
800076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
801076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
802076ba34aSJunchao Zhang             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
803076ba34aSJunchao Zhang           } else {                                                /* k in [cdend, clen) */
804076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
805076ba34aSJunchao Zhang             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
806076ba34aSJunchao Zhang           }
807076ba34aSJunchao Zhang         });
808076ba34aSJunchao Zhang       });
809076ba34aSJunchao Zhang 
810076ba34aSJunchao Zhang     Cdj_dual.modify_device();
811076ba34aSJunchao Zhang     Cda_dual.modify_device();
812076ba34aSJunchao Zhang     Coj_dual.modify_device();
813076ba34aSJunchao Zhang     Coa_dual.modify_device();
814076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
815076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
816076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
8179566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
8189566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
8199566063dSJacob Faibussowitsch     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
820076ba34aSJunchao Zhang   }
821076ba34aSJunchao Zhang   PetscFunctionReturn(0);
822076ba34aSJunchao Zhang }
823076ba34aSJunchao Zhang 
824076ba34aSJunchao Zhang /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
825076ba34aSJunchao Zhang 
826076ba34aSJunchao Zhang   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
827076ba34aSJunchao Zhang 
828076ba34aSJunchao Zhang   Input Parameters:
829076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
830076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
831076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
832076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
833076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array.
834076ba34aSJunchao Zhang 
835076ba34aSJunchao Zhang   Output Parameters:
836076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
837076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
838076ba34aSJunchao Zhang 
83911a5261eSBarry Smith   Note:
84011a5261eSBarry Smith   the input matrix's col ids and col size will be changed.
841076ba34aSJunchao Zhang */
842*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
843*d71ae5a4SJacob Faibussowitsch {
844076ba34aSJunchao Zhang   Mat_SeqAIJKokkos      *ckok;
845076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
846076ba34aSJunchao Zhang   const PetscInt        *garray;
847076ba34aSJunchao Zhang   PetscInt               sz;
848076ba34aSJunchao Zhang 
849076ba34aSJunchao Zhang   PetscFunctionBegin;
850076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
8519566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
852076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
853076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
854076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
855076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
856076ba34aSJunchao Zhang 
857076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
8589566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
8599566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
86008401ef6SPierre Jolivet   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
861076ba34aSJunchao Zhang 
862076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray, sz);
863076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g", sz);
864076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g, tmp);
865076ba34aSJunchao Zhang 
8669566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
8679566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
868076ba34aSJunchao Zhang   PetscFunctionReturn(0);
869076ba34aSJunchao Zhang }
870076ba34aSJunchao Zhang 
871076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
872076ba34aSJunchao Zhang 
873076ba34aSJunchao Zhang   Input Parameters:
874076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
875076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
876076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
877076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
878076ba34aSJunchao Zhang 
87911a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
880076ba34aSJunchao Zhang */
881*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
882*d71ae5a4SJacob Faibussowitsch {
883076ba34aSJunchao Zhang   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
884076ba34aSJunchao Zhang   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
885076ba34aSJunchao Zhang   IS                       glob = NULL;
886076ba34aSJunchao Zhang   const PetscInt          *garray;
887076ba34aSJunchao Zhang   PetscInt                 N = B->cmap->N, sz;
888076ba34aSJunchao Zhang   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
889076ba34aSJunchao Zhang   MatColIdxKokkosView      l2g2;
890076ba34aSJunchao Zhang   Mat                      C1, C2; /* intermediate matrices */
891076ba34aSJunchao Zhang 
892076ba34aSJunchao Zhang   PetscFunctionBegin;
893076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
8949566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
8959566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
8969566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
8979566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
898076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
8999566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
900dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
901076ba34aSJunchao Zhang 
9029566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(glob, &garray));
9039566063dSJacob Faibussowitsch   PetscCall(ISGetSize(glob, &sz));
904076ba34aSJunchao Zhang   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
905076ba34aSJunchao Zhang   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
9069566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
907076ba34aSJunchao Zhang 
908076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
9099566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
910076ba34aSJunchao Zhang 
911076ba34aSJunchao Zhang   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
9129566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
9139566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
9149566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
9159566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
916076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
9179566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
918dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
9199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
920076ba34aSJunchao Zhang 
921076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
922076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
923076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
924076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
925076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
926076ba34aSJunchao Zhang 
927076ba34aSJunchao Zhang   mm->C1 = C1;
928076ba34aSJunchao Zhang   mm->C2 = C2;
9299566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(glob, &garray));
9309566063dSJacob Faibussowitsch   PetscCall(ISDestroy(&glob));
931076ba34aSJunchao Zhang   PetscFunctionReturn(0);
932076ba34aSJunchao Zhang }
933076ba34aSJunchao Zhang 
934076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
935076ba34aSJunchao Zhang 
936076ba34aSJunchao Zhang   Input Parameters:
937076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
938076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
939076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
940076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
941076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
942076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
943076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
944076ba34aSJunchao Zhang 
94511a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
946076ba34aSJunchao Zhang */
947*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
948*d71ae5a4SJacob Faibussowitsch {
949076ba34aSJunchao Zhang   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
950076ba34aSJunchao Zhang   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
951076ba34aSJunchao Zhang   Mat         C1, C2;               /* intermediate matrices */
952076ba34aSJunchao Zhang 
953076ba34aSJunchao Zhang   PetscFunctionBegin;
954076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
9559566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
9569566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
9579566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
958076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
9599566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
960dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
961076ba34aSJunchao Zhang 
9629566063dSJacob Faibussowitsch   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
963076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
964076ba34aSJunchao Zhang 
965076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
9669566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
9679566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
9689566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
969076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
9709566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
971dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
972076ba34aSJunchao Zhang 
9739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
974076ba34aSJunchao Zhang 
975076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
976076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
977076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
978076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
979076ba34aSJunchao Zhang   mm->C1 = C1;
980076ba34aSJunchao Zhang   mm->C2 = C2;
981076ba34aSJunchao Zhang   PetscFunctionReturn(0);
982076ba34aSJunchao Zhang }
983076ba34aSJunchao Zhang 
984*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
985*d71ae5a4SJacob Faibussowitsch {
986076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
987076ba34aSJunchao Zhang   MatProductType               ptype;
988076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
989076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
990076ba34aSJunchao Zhang   MatMatStruct_AB             *ab;
991076ba34aSJunchao Zhang   MatMatStruct_AtB            *atb;
992076ba34aSJunchao Zhang   Mat                          A, B, Ad, Ao, Bd, Bo;
993076ba34aSJunchao Zhang   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
994076ba34aSJunchao Zhang 
995076ba34aSJunchao Zhang   PetscFunctionBegin;
996076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
997076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
998076ba34aSJunchao Zhang   ptype  = product->type;
999076ba34aSJunchao Zhang   A      = product->A;
1000076ba34aSJunchao Zhang   B      = product->B;
1001076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1002076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1003076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1004076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1005076ba34aSJunchao Zhang 
1006076ba34aSJunchao Zhang   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1007076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1008076ba34aSJunchao Zhang     ab               = mmdata->mmAB;
1009076ba34aSJunchao Zhang     atb              = mmdata->mmAtB;
1010076ba34aSJunchao Zhang     if (ab) {
1011076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1012076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1013076ba34aSJunchao Zhang     }
1014076ba34aSJunchao Zhang     if (atb) {
1015076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1016076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1017076ba34aSJunchao Zhang     }
1018076ba34aSJunchao Zhang     PetscFunctionReturn(0);
1019076ba34aSJunchao Zhang   }
1020076ba34aSJunchao Zhang 
1021076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1022076ba34aSJunchao Zhang     ab = mmdata->mmAB;
1023076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
102408401ef6SPierre Jolivet     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
10259566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
10265f80ce2aSJacob Faibussowitsch     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
10279566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
10289566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
10299371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1030076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
103108401ef6SPierre Jolivet     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
10329566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
10339566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1034076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1035076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1036076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(ab);
1037076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1038076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
103908401ef6SPierre Jolivet     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1040076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
10419566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
104208401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
10439566063dSJacob Faibussowitsch     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
10449566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1045076ba34aSJunchao Zhang 
1046076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
104708401ef6SPierre Jolivet     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
10489566063dSJacob Faibussowitsch     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
10499566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1050076ba34aSJunchao Zhang     /* Form C2_global */
10519371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1052076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1053076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1054076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1055076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1056076ba34aSJunchao Zhang     ab = mmdata->mmAB;
10579566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1058076ba34aSJunchao Zhang 
1059076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
106008401ef6SPierre Jolivet     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
10619566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
10629566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
10639371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1064076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
10659566063dSJacob Faibussowitsch     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
10669566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1067076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1068076ba34aSJunchao Zhang 
1069076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1070076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
107108401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
10729566063dSJacob Faibussowitsch     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
10739566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1074076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
10759566063dSJacob Faibussowitsch     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
10769566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
10779371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1078076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1079076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1080076ba34aSJunchao Zhang   }
1081076ba34aSJunchao Zhang   /* Split C_global to form C */
10829566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1083076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1084076ba34aSJunchao Zhang }
1085076ba34aSJunchao Zhang 
1086*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1087*d71ae5a4SJacob Faibussowitsch {
1088076ba34aSJunchao Zhang   Mat                          A, B;
1089076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1090076ba34aSJunchao Zhang   MatProductType               ptype;
1091076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1092076ba34aSJunchao Zhang   MatMatStruct                *mm   = NULL;
1093076ba34aSJunchao Zhang   IS                           glob = NULL;
1094076ba34aSJunchao Zhang   const PetscInt              *garray;
1095076ba34aSJunchao Zhang   PetscInt                     m, n, M, N, sz;
1096076ba34aSJunchao Zhang   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1097076ba34aSJunchao Zhang 
1098076ba34aSJunchao Zhang   PetscFunctionBegin;
1099076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
110028b400f6SJacob Faibussowitsch   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1101076ba34aSJunchao Zhang   ptype = product->type;
1102076ba34aSJunchao Zhang   A     = product->A;
1103076ba34aSJunchao Zhang   B     = product->B;
1104076ba34aSJunchao Zhang 
1105076ba34aSJunchao Zhang   switch (ptype) {
11069371c9d4SSatish Balay   case MATPRODUCT_AB:
11079371c9d4SSatish Balay     m = A->rmap->n;
11089371c9d4SSatish Balay     n = B->cmap->n;
11099371c9d4SSatish Balay     M = A->rmap->N;
11109371c9d4SSatish Balay     N = B->cmap->N;
11119371c9d4SSatish Balay     break;
11129371c9d4SSatish Balay   case MATPRODUCT_AtB:
11139371c9d4SSatish Balay     m = A->cmap->n;
11149371c9d4SSatish Balay     n = B->cmap->n;
11159371c9d4SSatish Balay     M = A->cmap->N;
11169371c9d4SSatish Balay     N = B->cmap->N;
11179371c9d4SSatish Balay     break;
11189371c9d4SSatish Balay   case MATPRODUCT_PtAP:
11199371c9d4SSatish Balay     m = B->cmap->n;
11209371c9d4SSatish Balay     n = B->cmap->n;
11219371c9d4SSatish Balay     M = B->cmap->N;
11229371c9d4SSatish Balay     N = B->cmap->N;
11239371c9d4SSatish Balay     break; /* BtAB */
1124*d71ae5a4SJacob Faibussowitsch   default:
1125*d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1126076ba34aSJunchao Zhang   }
1127076ba34aSJunchao Zhang 
11289566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
11299566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
11309566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
11319566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1132076ba34aSJunchao Zhang 
1133076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1134076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1135076ba34aSJunchao Zhang 
1136076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1137076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
11389566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1139076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1140076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1141076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1142076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
11439566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
11449566063dSJacob Faibussowitsch     PetscCall(ISGetIndices(glob, &garray));
11459566063dSJacob Faibussowitsch     PetscCall(ISGetSize(glob, &sz));
1146076ba34aSJunchao Zhang     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
11479566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
11489566063dSJacob Faibussowitsch     PetscCall(ISRestoreIndices(glob, &garray));
11499566063dSJacob Faibussowitsch     PetscCall(ISDestroy(&glob));
1150076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1151076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1152076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1153076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1154076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1155076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
11569566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1157076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
11589566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
11599566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1160076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1161076ba34aSJunchao Zhang   }
1162076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
11639566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1164076ba34aSJunchao Zhang   C->product->data       = mmdata;
1165076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1166076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1167076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1168076ba34aSJunchao Zhang }
1169076ba34aSJunchao Zhang 
1170*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1171*d71ae5a4SJacob Faibussowitsch {
1172076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1173076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1174076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1175076ba34aSJunchao Zhang 
1176076ba34aSJunchao Zhang   PetscFunctionBegin;
1177076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
117848a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1179076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1180076ba34aSJunchao Zhang     switch (product->type) {
1181076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1182076ba34aSJunchao Zhang       if (product->api_user) {
1183d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
11849566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1185d0609cedSBarry Smith         PetscOptionsEnd();
1186076ba34aSJunchao Zhang       } else {
1187d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
11889566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1189d0609cedSBarry Smith         PetscOptionsEnd();
1190076ba34aSJunchao Zhang       }
1191076ba34aSJunchao Zhang       break;
1192076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1193076ba34aSJunchao Zhang       if (product->api_user) {
1194d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
11959566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1196d0609cedSBarry Smith         PetscOptionsEnd();
1197076ba34aSJunchao Zhang       } else {
1198d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
11999566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1200d0609cedSBarry Smith         PetscOptionsEnd();
1201076ba34aSJunchao Zhang       }
1202076ba34aSJunchao Zhang       break;
1203076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1204076ba34aSJunchao Zhang       if (product->api_user) {
1205d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
12069566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1207d0609cedSBarry Smith         PetscOptionsEnd();
1208076ba34aSJunchao Zhang       } else {
1209d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
12109566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1211d0609cedSBarry Smith         PetscOptionsEnd();
1212076ba34aSJunchao Zhang       }
1213076ba34aSJunchao Zhang       break;
1214*d71ae5a4SJacob Faibussowitsch     default:
1215*d71ae5a4SJacob Faibussowitsch       break;
1216076ba34aSJunchao Zhang     }
1217076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1218076ba34aSJunchao Zhang   }
1219076ba34aSJunchao Zhang   if (match) {
1220076ba34aSJunchao Zhang     switch (product->type) {
1221076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1222076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1223*d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1224*d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1225*d71ae5a4SJacob Faibussowitsch       break;
1226*d71ae5a4SJacob Faibussowitsch     default:
1227*d71ae5a4SJacob Faibussowitsch       break;
1228076ba34aSJunchao Zhang     }
1229076ba34aSJunchao Zhang   }
1230076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
123148a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1232076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1233076ba34aSJunchao Zhang }
1234076ba34aSJunchao Zhang 
1235*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1236*d71ae5a4SJacob Faibussowitsch {
1237394ed5ebSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1238cbc6b225SStefano Zampini   Mat_MPIAIJKokkos *mpikok;
123942550becSJunchao Zhang 
124042550becSJunchao Zhang   PetscFunctionBegin;
12419566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j));
1242cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
12439566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
12449566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
12459566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
1246cbc6b225SStefano Zampini   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1247cbc6b225SStefano Zampini   delete mpikok;
1248394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
124942550becSJunchao Zhang   PetscFunctionReturn(0);
125042550becSJunchao Zhang }
125142550becSJunchao Zhang 
/*
  MatSetValuesCOO_MPIAIJKokkos - Set or add the COO values v[] into the diagonal (A) and
  off-diagonal (B) blocks of a MATMPIAIJKOKKOS matrix, doing the work on the device.

  Input Parameters:
+ mat   - the matrix; MatSetPreallocationCOO() must have been called so that spptr holds the COO maps
. v     - mpiaij->coo_n values, presumably in the order of the coo_i/coo_j given at preallocation
          (TODO confirm); may be host or device memory; host input is copied to the device first
- imode - INSERT_VALUES overwrites the existing values of the nonzeros touched by the local kernel;
          any other mode adds to them

  Steps:
  1. Entries owned by other ranks are packed (via Cperm1) into sendbuf_d.
  2. They are reduced with MPI_REPLACE into the owners' recvbuf_d through mpiaij->coo_sf.
  3. The reduction overlaps with a kernel that sums this rank's own entries into A and B.
  4. After the reduction completes, the received entries are added (+=) on top.
*/
static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
{
  Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
  Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
  Mat                         A = mpiaij->A, B = mpiaij->B;
  PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
  MatScalarKokkosView         Aa, Ba;
  MatScalarKokkosView         v1;
  MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
  const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
  /* Note: in the declarations below '&' binds only to the first declarator; the others are View copies (shallow handles) */
  const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
  const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
  const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
  const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
  PetscMemType                memtype;

  PetscFunctionBegin;
  PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
  if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
    v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
  } else {
    v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
  }

  /* INSERT_VALUES only needs write access (old values are discarded); ADD_VALUES must read them too */
  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
    PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
    PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
  }

  /* Pack entries to be sent to remote */
  Kokkos::parallel_for(
    vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

  /* Send remote entries to their owner and overlap the communication with local computation */
  PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
  /* Add local entries to A and B in one kernel */
  /* Iterations [0, Annz) handle A's nonzeros, [Annz, Annz + Bnnz) handle B's. For nonzero i, the
     input entries v1(Aperm1(k)), k in [Ajmap1(i), Ajmap1(i+1)), are summed (duplicates accumulate). */
  Kokkos::parallel_for(
    Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
      PetscScalar sum = 0.0;
      if (i < Annz) {
        for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
        Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
      } else {
        i -= Annz;
        for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
        Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
      }
    });
  PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

  /* Add received remote entries to A and B in one kernel */
  /* Aimap2(i)/Bimap2(i) give the position in A's/B's value array receiving remote contributions.
     NOTE(review): plain += assumes these positions are distinct across i, otherwise the kernel
     would race -- TODO confirm against the COO setup. */
  Kokkos::parallel_for(
    Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
      if (i < Annz2) {
        for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
      } else {
        i -= Annz2;
        for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
      }
    });

  /* Restore with the same access mode used to get the views */
  if (imode == INSERT_VALUES) {
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
    PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
  } else {
    PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
    PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
  }
  PetscFunctionReturn(0);
}
132542550becSJunchao Zhang 
1326*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1327*d71ae5a4SJacob Faibussowitsch {
132842550becSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1329076ba34aSJunchao Zhang 
1330076ba34aSJunchao Zhang   PetscFunctionBegin;
13319566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
13329566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
13339566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
13349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
133542550becSJunchao Zhang   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
13369566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
1337076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1338076ba34aSJunchao Zhang }
1339076ba34aSJunchao Zhang 
1340*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1341*d71ae5a4SJacob Faibussowitsch {
13428c3ff71bSJunchao Zhang   Mat         B;
1343076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
13448c3ff71bSJunchao Zhang 
13458c3ff71bSJunchao Zhang   PetscFunctionBegin;
13468c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
13479566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
13488c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
13499566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
13508c3ff71bSJunchao Zhang   }
13518c3ff71bSJunchao Zhang   B = *newmat;
13528c3ff71bSJunchao Zhang 
13536f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
13549566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
13559566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
13569566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
13578c3ff71bSJunchao Zhang 
1358076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
13599566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
13609566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
13619566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1362076ba34aSJunchao Zhang 
13638c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
13648c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
13658c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
13668c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1367076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1368076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
13698c3ff71bSJunchao Zhang 
13709566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
13719566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
13729566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
13739566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
13748c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13758c3ff71bSJunchao Zhang }
13763f3ba80aSJunchao Zhang /*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos

   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
13803f3ba80aSJunchao Zhang 
13813f3ba80aSJunchao Zhang    Options Database Keys:
13823f3ba80aSJunchao Zhang .  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()
13833f3ba80aSJunchao Zhang 
13843f3ba80aSJunchao Zhang   Level: beginner
13853f3ba80aSJunchao Zhang 
138611a5261eSBarry Smith .seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
13873f3ba80aSJunchao Zhang M*/
1388*d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1389*d71ae5a4SJacob Faibussowitsch {
13908c3ff71bSJunchao Zhang   PetscFunctionBegin;
13919566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
13929566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
13939566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
13948c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13958c3ff71bSJunchao Zhang }
13968c3ff71bSJunchao Zhang 
13978c3ff71bSJunchao Zhang /*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
14018c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
14028c3ff71bSJunchao Zhang    the parameter nz (or the array nnz).  By setting these parameters accurately,
14038c3ff71bSJunchao Zhang    performance during matrix assembly can be increased by more than a factor of 50.
14048c3ff71bSJunchao Zhang 
14058c3ff71bSJunchao Zhang    Collective
14068c3ff71bSJunchao Zhang 
14078c3ff71bSJunchao Zhang    Input Parameters:
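+  comm - MPI communicator
.  m - number of local rows (or `PETSC_DECIDE` to have it calculated if M is given)
.  n - number of local columns (or `PETSC_DECIDE` to have it calculated if N is given)
.  M - number of global rows (or `PETSC_DETERMINE` to have it calculated if m is given)
.  N - number of global columns (or `PETSC_DETERMINE` to have it calculated if n is given)
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion of the local submatrix
         (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion of the local submatrix
         (possibly different for each row) or NULL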
14148c3ff71bSJunchao Zhang 
14158c3ff71bSJunchao Zhang    Output Parameter:
14168c3ff71bSJunchao Zhang .  A - the matrix
14178c3ff71bSJunchao Zhang 
141811a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
14198c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradigm instead of this routine directly.
142011a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
14218c3ff71bSJunchao Zhang 
14228c3ff71bSJunchao Zhang    Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored. On a single-process communicator a `MATSEQAIJKOKKOS` matrix is created and o_nz, o_nnz are ignored.
14248c3ff71bSJunchao Zhang 
   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
14268c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
14278c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
14288c3ff71bSJunchao Zhang 
   Specify the preallocated storage with either d_nz/o_nz or d_nnz/o_nnz (not both).
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
14318c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
14328c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
14338c3ff71bSJunchao Zhang 
14348c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
14358c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
14368c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
14378c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
14388c3ff71bSJunchao Zhang 
14398c3ff71bSJunchao Zhang    Level: intermediate
14408c3ff71bSJunchao Zhang 
.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
14428c3ff71bSJunchao Zhang @*/
1443*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1444*d71ae5a4SJacob Faibussowitsch {
14458c3ff71bSJunchao Zhang   PetscMPIInt size;
14468c3ff71bSJunchao Zhang 
14478c3ff71bSJunchao Zhang   PetscFunctionBegin;
14489566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14499566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
14509566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
14518c3ff71bSJunchao Zhang   if (size > 1) {
14529566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
14539566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
14548c3ff71bSJunchao Zhang   } else {
14559566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
14569566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
14578c3ff71bSJunchao Zhang   }
14588c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
14598c3ff71bSJunchao Zhang }
14608c3ff71bSJunchao Zhang 
1461a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1462*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1463*d71ae5a4SJacob Faibussowitsch {
1464a587d139SMark   PetscMPIInt                size, rank;
1465a587d139SMark   MPI_Comm                   comm;
1466042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat = NULL;
1467a587d139SMark 
1468a587d139SMark   PetscFunctionBegin;
14699566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
14709566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
14719566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1472a587d139SMark   if (size == 1) {
14739566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
14749566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1475a587d139SMark   } else {
1476a587d139SMark     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
14779566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
14789566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
14799566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
14802c71b3e2SJacob Faibussowitsch     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1481a587d139SMark   }
1482a587d139SMark   // act like MatSetValues because not called on host
1483a587d139SMark   if (A->assembled) {
148448a46eb9SPierre Jolivet     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1485a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1486a587d139SMark   } else {
14879566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1488a587d139SMark   }
1489a587d139SMark   if (!d_mat) {
1490042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1491a587d139SMark     Mat_SeqAIJKokkos     *aijkokA;
1492a587d139SMark     Mat_SeqAIJ           *jaca;
1493a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1494a587d139SMark     Mat                   Amat;
1495042217e8SBarry Smith     PetscInt             *colmap;
1496042217e8SBarry Smith 
1497042217e8SBarry Smith     /* create and copy h_mat */
149849b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
14999566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1500a587d139SMark     if (size == 1) {
1501a587d139SMark       Amat            = A;
1502a587d139SMark       jaca            = (Mat_SeqAIJ *)A->data;
15039371c9d4SSatish Balay       h_mat.rstart    = 0;
15049371c9d4SSatish Balay       h_mat.rend      = A->rmap->n;
15059371c9d4SSatish Balay       h_mat.cstart    = 0;
15069371c9d4SSatish Balay       h_mat.cend      = A->cmap->n;
1507a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1508a587d139SMark       h_mat.offdiag.a                   = NULL;
1509a587d139SMark       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1510a587d139SMark     } else {
1511a587d139SMark       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1512a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1513a587d139SMark       PetscInt          ii;
1514a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1515042217e8SBarry Smith 
1516a587d139SMark       Amat    = aij->A;
1517a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1518a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1519a587d139SMark       jaca    = (Mat_SeqAIJ *)aij->A->data;
152008401ef6SPierre Jolivet       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
152108401ef6SPierre Jolivet       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1522a587d139SMark       aij->donotstash          = PETSC_TRUE;
1523a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1524a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
15259566063dSJacob Faibussowitsch       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1526042217e8SBarry Smith       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1527a587d139SMark       // allocate B copy data
15289371c9d4SSatish Balay       h_mat.rstart = A->rmap->rstart;
15299371c9d4SSatish Balay       h_mat.rend   = A->rmap->rend;
15309371c9d4SSatish Balay       h_mat.cstart = A->cmap->rstart;
15319371c9d4SSatish Balay       h_mat.cend   = A->cmap->rend;
1532a587d139SMark       nnz          = jacb->i[n];
1533a587d139SMark       if (jacb->compressedrow.use) {
1534a587d139SMark         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1535300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1536300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1537300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1538a587d139SMark       } else {
153999551766SMark Adams         h_mat.offdiag.i = aijkokB->i_device_data();
1540a587d139SMark       }
154199551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1542076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1543a587d139SMark       {
1544042217e8SBarry Smith         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1545300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1546300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1547300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
15489566063dSJacob Faibussowitsch         PetscCall(PetscFree(colmap));
1549a587d139SMark       }
1550a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1551a587d139SMark       h_mat.offdiag.n                 = n;
1552a587d139SMark     }
1553a587d139SMark     // allocate A copy data
1554a587d139SMark     nnz                          = jaca->i[n];
1555a587d139SMark     h_mat.diag.n                 = n;
1556a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
15579566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1558aed4548fSBarry Smith     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not suppport compressed row (todo)");
155999551766SMark Adams     h_mat.diag.i = aijkokA->i_device_data();
156099551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1561076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1562a587d139SMark     // copy pointers and metdata to device
15639566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
15649566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
15659566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1566a587d139SMark   }
1567a587d139SMark   *B           = d_mat;       // return it, set it in Mat, and set it up
1568a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1569a587d139SMark   PetscFunctionReturn(0);
1570a587d139SMark }
1571076ba34aSJunchao Zhang 
1572*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1573*d71ae5a4SJacob Faibussowitsch {
1574076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1575076ba34aSJunchao Zhang 
1576076ba34aSJunchao Zhang   PetscFunctionBegin;
1577076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1578076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1579076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1580076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
1581076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1582076ba34aSJunchao Zhang }
1583076ba34aSJunchao Zhang 
1584*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1585*d71ae5a4SJacob Faibussowitsch {
1586076ba34aSJunchao Zhang   PetscMPIInt size;
1587076ba34aSJunchao Zhang   Mat         Ad, Ao;
1588076ba34aSJunchao Zhang   const char *amask, *bmask;
1589076ba34aSJunchao Zhang 
1590076ba34aSJunchao Zhang   PetscFunctionBegin;
15919566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1592076ba34aSJunchao Zhang 
1593076ba34aSJunchao Zhang   if (size == 1) {
15949566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
15959566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1596076ba34aSJunchao Zhang   } else {
1597076ba34aSJunchao Zhang     Ad = ((Mat_MPIAIJ *)A->data)->A;
1598076ba34aSJunchao Zhang     Ao = ((Mat_MPIAIJ *)A->data)->B;
15999566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
16009566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
16019566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1602076ba34aSJunchao Zhang   }
1603076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1604076ba34aSJunchao Zhang }
1605