xref: /petsc/src/mat/impls/aij/mpi/kokkos/mpiaijkok.kokkos.cxx (revision 3ba1676111f5c958fe6c2729b46ca4d523958bb3)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscsf.h>
342550becSJunchao Zhang #include <petsc/private/sfimpl.h>
48c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
542550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
6076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
711d22bbfSJunchao Zhang 
8d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
9d71ae5a4SJacob Faibussowitsch {
105519a089SJose E. Roman   Mat_SeqAIJKokkos *aijkok;
1130203840SJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;
128c3ff71bSJunchao Zhang 
138c3ff71bSJunchao Zhang   PetscFunctionBegin;
149566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
1530203840SJunchao Zhang   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
1630203840SJunchao Zhang      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
1730203840SJunchao Zhang    */
1830203840SJunchao Zhang   if (mode == MAT_FINAL_ASSEMBLY) {
1930203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
2030203840SJunchao Zhang     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
2130203840SJunchao Zhang     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
2230203840SJunchao Zhang   }
235519a089SJose E. Roman   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
24a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
25a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
26a587d139SMark   }
27a587d139SMark 
28*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
298c3ff71bSJunchao Zhang }
308c3ff71bSJunchao Zhang 
31d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
32d71ae5a4SJacob Faibussowitsch {
338c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
348c3ff71bSJunchao Zhang 
358c3ff71bSJunchao Zhang   PetscFunctionBegin;
369566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->rmap));
379566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(mat->cmap));
386a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG)
398c3ff71bSJunchao Zhang   if (d_nnz) {
406a29ce69SStefano Zampini     PetscInt i;
41ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
428c3ff71bSJunchao Zhang   }
438c3ff71bSJunchao Zhang   if (o_nnz) {
446a29ce69SStefano Zampini     PetscInt i;
45ad540459SPierre Jolivet     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
468c3ff71bSJunchao Zhang   }
476a29ce69SStefano Zampini #endif
486a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE)
49eec179cfSJacob Faibussowitsch   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
506a29ce69SStefano Zampini #else
519566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->colmap));
526a29ce69SStefano Zampini #endif
539566063dSJacob Faibussowitsch   PetscCall(PetscFree(mpiaij->garray));
549566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&mpiaij->lvec));
559566063dSJacob Faibussowitsch   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
566a29ce69SStefano Zampini   /* Because the B will have been resized we simply destroy it and create a new one each time */
579566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&mpiaij->B));
586a29ce69SStefano Zampini 
596a29ce69SStefano Zampini   if (!mpiaij->A) {
609566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
619566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
626a29ce69SStefano Zampini   }
636a29ce69SStefano Zampini   if (!mpiaij->B) {
646a29ce69SStefano Zampini     PetscMPIInt size;
659566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
669566063dSJacob Faibussowitsch     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
679566063dSJacob Faibussowitsch     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
688c3ff71bSJunchao Zhang   }
699566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
709566063dSJacob Faibussowitsch   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
719566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
738c3ff71bSJunchao Zhang   mat->preallocated = PETSC_TRUE;
74*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
758c3ff71bSJunchao Zhang }
768c3ff71bSJunchao Zhang 
77d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
78d71ae5a4SJacob Faibussowitsch {
798c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
808c3ff71bSJunchao Zhang   PetscInt    nt;
818c3ff71bSJunchao Zhang 
828c3ff71bSJunchao Zhang   PetscFunctionBegin;
839566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
8408401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
859566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
869566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
879566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
889566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
89*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
908c3ff71bSJunchao Zhang }
918c3ff71bSJunchao Zhang 
92d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
93d71ae5a4SJacob Faibussowitsch {
948c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
958c3ff71bSJunchao Zhang   PetscInt    nt;
968c3ff71bSJunchao Zhang 
978c3ff71bSJunchao Zhang   PetscFunctionBegin;
989566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
9908401ef6SPierre Jolivet   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
1009566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1019566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
1029566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
1039566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
104*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1058c3ff71bSJunchao Zhang }
1068c3ff71bSJunchao Zhang 
107d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
108d71ae5a4SJacob Faibussowitsch {
1098c3ff71bSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
1108c3ff71bSJunchao Zhang   PetscInt    nt;
1118c3ff71bSJunchao Zhang 
1128c3ff71bSJunchao Zhang   PetscFunctionBegin;
1139566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(xx, &nt));
11408401ef6SPierre Jolivet   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
1159566063dSJacob Faibussowitsch   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
1169566063dSJacob Faibussowitsch   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
1179566063dSJacob Faibussowitsch   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
1189566063dSJacob Faibussowitsch   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1208c3ff71bSJunchao Zhang }
1218c3ff71bSJunchao Zhang 
122076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
123076ba34aSJunchao Zhang    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
124076ba34aSJunchao Zhang    C still uses local column ids. Their corresponding global column ids are returned in glob.
125076ba34aSJunchao Zhang */
126d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
127d71ae5a4SJacob Faibussowitsch {
128076ba34aSJunchao Zhang   Mat             Ad, Ao;
129076ba34aSJunchao Zhang   const PetscInt *cmap;
130076ba34aSJunchao Zhang 
131076ba34aSJunchao Zhang   PetscFunctionBegin;
1329566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
1339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
134076ba34aSJunchao Zhang   if (glob) {
135076ba34aSJunchao Zhang     PetscInt cst, i, dn, on, *gidx;
1369566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
1379566063dSJacob Faibussowitsch     PetscCall(MatGetLocalSize(Ao, NULL, &on));
1389566063dSJacob Faibussowitsch     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
1399566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(dn + on, &gidx));
140076ba34aSJunchao Zhang     for (i = 0; i < dn; i++) gidx[i] = cst + i;
141076ba34aSJunchao Zhang     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
1429566063dSJacob Faibussowitsch     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
143076ba34aSJunchao Zhang   }
144*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
145076ba34aSJunchao Zhang }
146076ba34aSJunchao Zhang 
147076ba34aSJunchao Zhang /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
148076ba34aSJunchao Zhang struct MatMatStruct {
149076ba34aSJunchao Zhang   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
150076ba34aSJunchao Zhang   PetscSF             sf;      /* SF to send/recv matrix entries */
151076ba34aSJunchao Zhang   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
152076ba34aSJunchao Zhang   Mat                 C1, C2, B_local;
153076ba34aSJunchao Zhang   KokkosCsrMatrix     C1_global, C2_global, C_global;
154076ba34aSJunchao Zhang   KernelHandle        kh;
155*3ba16761SJacob Faibussowitsch   MatMatStruct() noexcept : sf(nullptr), C1(nullptr), C2(nullptr), B_local(nullptr) { }
156076ba34aSJunchao Zhang 
157d71ae5a4SJacob Faibussowitsch   ~MatMatStruct()
158d71ae5a4SJacob Faibussowitsch   {
159*3ba16761SJacob Faibussowitsch     PetscFunctionBegin;
160*3ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C1));
161*3ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C2));
162*3ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_local));
163*3ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
164076ba34aSJunchao Zhang     kh.destroy_spadd_handle();
165*3ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
166076ba34aSJunchao Zhang   }
167076ba34aSJunchao Zhang };
168076ba34aSJunchao Zhang 
169076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct {
170*3ba16761SJacob Faibussowitsch   MatColIdxKokkosView rows{};
171*3ba16761SJacob Faibussowitsch   MatRowMapKokkosView rowoffset{};
172*3ba16761SJacob Faibussowitsch   Mat                 B_other{}, C_petsc{}; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
173076ba34aSJunchao Zhang 
174*3ba16761SJacob Faibussowitsch   ~MatMatStruct_AB() noexcept
175d71ae5a4SJacob Faibussowitsch   {
176*3ba16761SJacob Faibussowitsch     PetscFunctionBegin;
177*3ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_other));
178*3ba16761SJacob Faibussowitsch     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C_petsc));
179*3ba16761SJacob Faibussowitsch     PetscFunctionReturnVoid();
180076ba34aSJunchao Zhang   }
181076ba34aSJunchao Zhang };
182076ba34aSJunchao Zhang 
183076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct {
184076ba34aSJunchao Zhang   MatRowMapKokkosView srcrowoffset, dstrowoffset;
185076ba34aSJunchao Zhang };
186076ba34aSJunchao Zhang 
1879371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos {
188*3ba16761SJacob Faibussowitsch   MatMatStruct_AB  *mmAB     = nullptr;
189*3ba16761SJacob Faibussowitsch   MatMatStruct_AtB *mmAtB    = nullptr;
190*3ba16761SJacob Faibussowitsch   PetscBool         reusesym = PETSC_FALSE;
191076ba34aSJunchao Zhang 
192d71ae5a4SJacob Faibussowitsch   ~MatProductData_MPIAIJKokkos()
193d71ae5a4SJacob Faibussowitsch   {
194076ba34aSJunchao Zhang     delete mmAB;
195076ba34aSJunchao Zhang     delete mmAtB;
196076ba34aSJunchao Zhang   }
197076ba34aSJunchao Zhang };
198076ba34aSJunchao Zhang 
199d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
200d71ae5a4SJacob Faibussowitsch {
201076ba34aSJunchao Zhang   PetscFunctionBegin;
2029566063dSJacob Faibussowitsch   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
203*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
204076ba34aSJunchao Zhang }
205076ba34aSJunchao Zhang 
206076ba34aSJunchao Zhang /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix
207076ba34aSJunchao Zhang 
208076ba34aSJunchao Zhang    Input Parameters:
209076ba34aSJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
210076ba34aSJunchao Zhang .  N       - new column size for the returned Kokkos matrix
211076ba34aSJunchao Zhang -  l2g     - a map that maps old col ids to new col ids
212076ba34aSJunchao Zhang 
213076ba34aSJunchao Zhang    Output Parameters:
214076ba34aSJunchao Zhang .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
215076ba34aSJunchao Zhang  */
216d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
217d71ae5a4SJacob Faibussowitsch {
218076ba34aSJunchao Zhang   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
219076ba34aSJunchao Zhang   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */
220076ba34aSJunchao Zhang 
221076ba34aSJunchao Zhang   PetscFunctionBegin;
2229371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
2239371c9d4SSatish Balay     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
2249566063dSJacob Faibussowitsch   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
225*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
226076ba34aSJunchao Zhang }
227076ba34aSJunchao Zhang 
228076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
229076ba34aSJunchao Zhang    It is similar to MatCreateMPIAIJWithSplitArrays.
230076ba34aSJunchao Zhang 
231076ba34aSJunchao Zhang   Input Parameters:
232076ba34aSJunchao Zhang +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
233076ba34aSJunchao Zhang .  A     - the diag matrix using local col ids
234076ba34aSJunchao Zhang -  B     - the offdiag matrix using global col ids
235076ba34aSJunchao Zhang 
236076ba34aSJunchao Zhang   Output Parameters:
237076ba34aSJunchao Zhang .  mat   - the updated MATMPIAIJKOKKOS matrix
238076ba34aSJunchao Zhang */
239d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
240d71ae5a4SJacob Faibussowitsch {
241076ba34aSJunchao Zhang   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
242076ba34aSJunchao Zhang   PetscInt          m, n, M, N, Am, An, Bm, Bn;
243076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
244076ba34aSJunchao Zhang 
245076ba34aSJunchao Zhang   PetscFunctionBegin;
2469566063dSJacob Faibussowitsch   PetscCall(MatGetSize(mat, &M, &N));
2479566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(mat, &m, &n));
2489566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(A, &Am, &An));
2499566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(B, &Bm, &Bn));
250076ba34aSJunchao Zhang 
251aed4548fSBarry Smith   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
25208401ef6SPierre Jolivet   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
25308401ef6SPierre Jolivet   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
25408401ef6SPierre Jolivet   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
255076ba34aSJunchao Zhang   mpiaij->A = A;
256076ba34aSJunchao Zhang   mpiaij->B = B;
257076ba34aSJunchao Zhang 
258076ba34aSJunchao Zhang   mat->preallocated     = PETSC_TRUE;
259076ba34aSJunchao Zhang   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
260076ba34aSJunchao Zhang 
2619566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
2629566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
263076ba34aSJunchao Zhang   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
264076ba34aSJunchao Zhang     also gets mpiaij->B compacted, with its col ids and size reduced
265076ba34aSJunchao Zhang   */
2669566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
2679566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
2689566063dSJacob Faibussowitsch   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
269076ba34aSJunchao Zhang 
270076ba34aSJunchao Zhang   /* Update bkok with new local col ids (stored on host) and size */
271076ba34aSJunchao Zhang   bkok->j_dual.modify_host();
272076ba34aSJunchao Zhang   bkok->j_dual.sync_device();
273076ba34aSJunchao Zhang   bkok->SetColSize(mpiaij->B->cmap->n);
274*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
275076ba34aSJunchao Zhang }
276076ba34aSJunchao Zhang 
/* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrix (B) to form a SEQAIJKOKKOS matrix (C).
278076ba34aSJunchao Zhang 
279076ba34aSJunchao Zhang    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
280076ba34aSJunchao Zhang    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
281076ba34aSJunchao Zhang    Suppose C's j-th row is connected to a root identified by PetscSFNode (k,i), it means we will bcast the i-th row of B on rank k
282076ba34aSJunchao Zhang    to j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).
283076ba34aSJunchao Zhang 
284076ba34aSJunchao Zhang    Collective on comm of ownerSF
285076ba34aSJunchao Zhang 
286076ba34aSJunchao Zhang    Input Parameters:
287076ba34aSJunchao Zhang +   B       - the SEQAIJKOKKOS matrix, using local col ids
288076ba34aSJunchao Zhang .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
289076ba34aSJunchao Zhang .   N       - global col ids are in range of [0,N). N Must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
290076ba34aSJunchao Zhang .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
291076ba34aSJunchao Zhang .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)
292076ba34aSJunchao Zhang 
   Input/Output Parameters (out when reuse = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
294076ba34aSJunchao Zhang +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
295076ba34aSJunchao Zhang .   abuf      - buffer for sending matrix values
296076ba34aSJunchao Zhang .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
297076ba34aSJunchao Zhang                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
298076ba34aSJunchao Zhang .   rowoffset - For each row in rows[], it will be copied to rowoffset[] at abuf[]
299076ba34aSJunchao Zhang -   C         -  the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
300076ba34aSJunchao Zhang */
301d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
302d71ae5a4SJacob Faibussowitsch {
303076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *bkok, *ckok;
304076ba34aSJunchao Zhang 
305076ba34aSJunchao Zhang   PetscFunctionBegin;
3069566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
307076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
308076ba34aSJunchao Zhang 
309076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
310076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
311076ba34aSJunchao Zhang 
312076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
313076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
314076ba34aSJunchao Zhang     const auto &Ca = ckok->a_dual.view_device();
315076ba34aSJunchao Zhang 
316076ba34aSJunchao Zhang     /* Copy Ba to abuf */
3179371c9d4SSatish Balay     Kokkos::parallel_for(
3189371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
319076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
320076ba34aSJunchao Zhang         PetscInt r    = rows(i);
321076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
3229371c9d4SSatish Balay         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
323076ba34aSJunchao Zhang       });
324076ba34aSJunchao Zhang 
325076ba34aSJunchao Zhang     /* Send abuf to Ca through bcastSF and then mark C is updated on device */
3269566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
3279566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
328076ba34aSJunchao Zhang     ckok->a_dual.modify_device();
329076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
330076ba34aSJunchao Zhang     MPI_Comm    comm;
331076ba34aSJunchao Zhang     PetscMPIInt tag;
332076ba34aSJunchao Zhang     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;
333076ba34aSJunchao Zhang 
3349566063dSJacob Faibussowitsch     PetscCallMPI(PetscObjectGetComm((PetscObject)ownerSF, &comm));
3359566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
336076ba34aSJunchao Zhang     Cm = nleaves; /* row size of C */
337076ba34aSJunchao Zhang     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */
338076ba34aSJunchao Zhang 
339076ba34aSJunchao Zhang     /* Get row lens (nz) of B's rows for later fast query */
340076ba34aSJunchao Zhang     PetscInt       *Browlens;
341076ba34aSJunchao Zhang     const PetscInt *tmp = bkok->i_host_data();
3429566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nroots, &Browlens));
343076ba34aSJunchao Zhang     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];
344076ba34aSJunchao Zhang 
345076ba34aSJunchao Zhang     /* By ownerSF, each proc gets lens of rows of C */
346076ba34aSJunchao Zhang     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
347076ba34aSJunchao Zhang     Ci_h    = Ci.view_host().data();
348076ba34aSJunchao Zhang     Ci_h[0] = 0;
3499566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
3509566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
351076ba34aSJunchao Zhang     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
352076ba34aSJunchao Zhang     Cnnz = Ci_h[Cm];
353076ba34aSJunchao Zhang     Ci.modify_host();
354076ba34aSJunchao Zhang     Ci.sync_device();
355076ba34aSJunchao Zhang 
356076ba34aSJunchao Zhang     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
357076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cj("j", Cnnz);
358076ba34aSJunchao Zhang     MatScalarKokkosDualView Ca("a", Cnnz);
359076ba34aSJunchao Zhang 
360076ba34aSJunchao Zhang     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
361076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
362076ba34aSJunchao Zhang     const PetscInt    *ioffset, *irootloc, *roffset;
363076ba34aSJunchao Zhang     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
364076ba34aSJunchao Zhang     MPI_Request       *reqs;
365076ba34aSJunchao Zhang 
3669566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
3679566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */
368076ba34aSJunchao Zhang 
369076ba34aSJunchao Zhang     /* figure out offsets at the send buffer, to build the SF
370076ba34aSJunchao Zhang       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
371076ba34aSJunchao Zhang       rowptr[] - stores offsets for data of each row in abuf
372076ba34aSJunchao Zhang 
373076ba34aSJunchao Zhang       rdisp[]  - to receive sdisp[]
374076ba34aSJunchao Zhang     */
3759566063dSJacob Faibussowitsch     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
376076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
377076ba34aSJunchao Zhang     rowptr = rowptr_h.data();
378076ba34aSJunchao Zhang 
379076ba34aSJunchao Zhang     sdisp[0]  = 0;
380076ba34aSJunchao Zhang     rowptr[0] = 0;
381076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) { /* for each receiver */
382076ba34aSJunchao Zhang       PetscInt len, nz = 0;
383076ba34aSJunchao Zhang       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
384076ba34aSJunchao Zhang         len           = Browlens[irootloc[j]];
385076ba34aSJunchao Zhang         rowptr[j + 1] = rowptr[j] + len;
386076ba34aSJunchao Zhang         nz += len;
387076ba34aSJunchao Zhang       }
388076ba34aSJunchao Zhang       sdisp[i + 1] = sdisp[i] + nz;
389076ba34aSJunchao Zhang     }
3909566063dSJacob Faibussowitsch     PetscCallMPI(PetscCommGetNewTag(comm, &tag));
3919566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
3929566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
3939566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
394076ba34aSJunchao Zhang 
395076ba34aSJunchao Zhang     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
396076ba34aSJunchao Zhang     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
397076ba34aSJunchao Zhang     PetscSFNode *iremote;
3989566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote));
399076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) { /* for each sender */
400076ba34aSJunchao Zhang       k = 0;
401076ba34aSJunchao Zhang       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
402076ba34aSJunchao Zhang         iremote[j].rank  = ranks[i];
403076ba34aSJunchao Zhang         iremote[j].index = rdisp[i] + k;
404076ba34aSJunchao Zhang         k++;
405076ba34aSJunchao Zhang       }
406076ba34aSJunchao Zhang     }
407076ba34aSJunchao Zhang     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
4089566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &bcastSF));
4099566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
410076ba34aSJunchao Zhang 
411076ba34aSJunchao Zhang     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
412076ba34aSJunchao Zhang       from local to global. Then use bcastSF to fill Ca, Cj.
413076ba34aSJunchao Zhang     */
414076ba34aSJunchao Zhang     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
415076ba34aSJunchao Zhang     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
416076ba34aSJunchao Zhang     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootoc is managed by PetscSF and we want 'rows' to be standalone */
417076ba34aSJunchao Zhang 
418076ba34aSJunchao Zhang     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */
419076ba34aSJunchao Zhang 
420076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
421076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */
422076ba34aSJunchao Zhang 
423076ba34aSJunchao Zhang     const auto &Ba = bkok->a_dual.view_device();
424076ba34aSJunchao Zhang     const auto &Bi = bkok->i_dual.view_device();
425076ba34aSJunchao Zhang     const auto &Bj = bkok->j_dual.view_device();
426076ba34aSJunchao Zhang 
427076ba34aSJunchao Zhang     /* Copy Ba, Bj to abuf, jbuf with change col ids from local to global */
4289371c9d4SSatish Balay     Kokkos::parallel_for(
4299371c9d4SSatish Balay       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
430076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
431076ba34aSJunchao Zhang         PetscInt r    = rows(i);
432076ba34aSJunchao Zhang         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
433076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
434076ba34aSJunchao Zhang           abuf(base + k) = Ba(Bi(r) + k);
435076ba34aSJunchao Zhang           jbuf(base + k) = l2g(Bj(Bi(r) + k));
436076ba34aSJunchao Zhang         });
437076ba34aSJunchao Zhang       });
438076ba34aSJunchao Zhang 
439076ba34aSJunchao Zhang     /* Send abuf & jbuf to fill Ca, Cj */
4409566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4419566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
4429566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
4439566063dSJacob Faibussowitsch     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
444076ba34aSJunchao Zhang     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
445076ba34aSJunchao Zhang     Cj.sync_host();
446076ba34aSJunchao Zhang     Ca.modify_device();
447076ba34aSJunchao Zhang 
448076ba34aSJunchao Zhang     /* Construct C with Ca, Ci, Cj */
449076ba34aSJunchao Zhang     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
4509566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
4519566063dSJacob Faibussowitsch     PetscCall(PetscFree3(sdisp, rdisp, reqs));
4529566063dSJacob Faibussowitsch     PetscCall(PetscFree(Browlens));
45398921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
454*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
455076ba34aSJunchao Zhang }
456076ba34aSJunchao Zhang 
457076ba34aSJunchao Zhang /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)
458076ba34aSJunchao Zhang 
459076ba34aSJunchao Zhang   It is the reverse of MatSeqAIJKokkosBcast in some sense.
460076ba34aSJunchao Zhang 
461076ba34aSJunchao Zhang   Think each row of A as a leaf, then the given ownerSF specifies roots of the leaves. Roots may connect to multiple leaves.
462076ba34aSJunchao Zhang   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
463076ba34aSJunchao Zhang   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.
464076ba34aSJunchao Zhang 
465076ba34aSJunchao Zhang   Input Parameters:
466076ba34aSJunchao Zhang +  A        - the SEQAIJKOKKOS matrix to be reduced
467076ba34aSJunchao Zhang .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
468076ba34aSJunchao Zhang .  local    - true if A uses local col ids; false if A is already in global col ids.
469076ba34aSJunchao Zhang .  N        - if local, N is A's global col size
470076ba34aSJunchao Zhang .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
471076ba34aSJunchao Zhang -  ownerSF  - the SF specifies ownership (root) of rows in A
472076ba34aSJunchao Zhang 
473076ba34aSJunchao Zhang   Output Parameters:
474076ba34aSJunchao Zhang +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
475076ba34aSJunchao Zhang .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
476076ba34aSJunchao Zhang .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
477076ba34aSJunchao Zhang .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
478076ba34aSJunchao Zhang                   unrelated places in Ca, so dstrowoffset is not in CSR-like format as srcrowoffset.
479076ba34aSJunchao Zhang -  C            - the matrix made up by rows sent to me from other ranks, using global col ids
480076ba34aSJunchao Zhang 
481076ba34aSJunchao Zhang    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide oppertunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
482076ba34aSJunchao Zhang  */
483d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
484d71ae5a4SJacob Faibussowitsch {
485076ba34aSJunchao Zhang   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
486076ba34aSJunchao Zhang   const PetscInt   *Ai;
487076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok;
488076ba34aSJunchao Zhang 
489076ba34aSJunchao Zhang   PetscFunctionBegin;
4909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
4919566063dSJacob Faibussowitsch   PetscCall(MatGetSize(A, &Am, &An));
492076ba34aSJunchao Zhang   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
493076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
494076ba34aSJunchao Zhang   Annz = Ai[Am];
495076ba34aSJunchao Zhang 
496076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
497076ba34aSJunchao Zhang     /* Send Aa to abuf */
4989566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
4999566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
500076ba34aSJunchao Zhang 
501076ba34aSJunchao Zhang     /* Copy abuf to Ca */
502076ba34aSJunchao Zhang     const MatScalarKokkosView &Ca = C.values;
503076ba34aSJunchao Zhang     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
5049371c9d4SSatish Balay     Kokkos::parallel_for(
5059371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
506076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
507076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
508076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
509076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
510076ba34aSJunchao Zhang       });
511076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
512076ba34aSJunchao Zhang     MPI_Comm     comm;
513076ba34aSJunchao Zhang     MPI_Request *reqs;
514076ba34aSJunchao Zhang     PetscMPIInt  tag;
515076ba34aSJunchao Zhang     PetscInt     Cm;
516076ba34aSJunchao Zhang 
5179566063dSJacob Faibussowitsch     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
5189566063dSJacob Faibussowitsch     PetscCall(PetscCommGetNewTag(comm, &tag));
519076ba34aSJunchao Zhang 
520076ba34aSJunchao Zhang     PetscInt           niranks, nranks, nroots, nleaves;
521076ba34aSJunchao Zhang     const PetscMPIInt *iranks, *ranks;
522076ba34aSJunchao Zhang     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
5239566063dSJacob Faibussowitsch     PetscCall(PetscSFSetUp(ownerSF));
5249566063dSJacob Faibussowitsch     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
5259566063dSJacob Faibussowitsch     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
5269566063dSJacob Faibussowitsch     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
52708401ef6SPierre Jolivet     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
528076ba34aSJunchao Zhang     Cm    = nroots;
529076ba34aSJunchao Zhang     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */
530076ba34aSJunchao Zhang 
531076ba34aSJunchao Zhang     /* Tell owners how long each row I will send */
532076ba34aSJunchao Zhang     PetscInt               *srowlens;                              /* send buf of row lens */
533076ba34aSJunchao Zhang     MatRowMapKokkosViewHost rrowlens_h("rrowoffset_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
534076ba34aSJunchao Zhang     PetscInt               *rrowlens = rrowlens_h.data();
535076ba34aSJunchao Zhang 
5369566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
537076ba34aSJunchao Zhang     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
538076ba34aSJunchao Zhang     rrowlens[0] = 0;
539076ba34aSJunchao Zhang     rrowlens++; /* shift the pointer to make the following expression more readable */
5409566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
5419566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
5429566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
543076ba34aSJunchao Zhang 
544076ba34aSJunchao Zhang     /* Owner builds Ci on host by histogramming rrowlens[] */
545076ba34aSJunchao Zhang     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
546076ba34aSJunchao Zhang     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
547076ba34aSJunchao Zhang     MatRowMapType *Ci_ptr = Ci_h.data();
548076ba34aSJunchao Zhang 
549076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {
550076ba34aSJunchao Zhang       r = rows[i]; /* local row id of i-th received row */
551076ba34aSJunchao Zhang #if defined(PETSC_USE_DEBUG)
552aed4548fSBarry Smith       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
553076ba34aSJunchao Zhang #endif
554076ba34aSJunchao Zhang       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
555076ba34aSJunchao Zhang     }
556076ba34aSJunchao Zhang     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
557076ba34aSJunchao Zhang     Cnnz = Ci_ptr[Cm];
558076ba34aSJunchao Zhang 
559076ba34aSJunchao Zhang     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
560076ba34aSJunchao Zhang     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
561076ba34aSJunchao Zhang     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
562076ba34aSJunchao Zhang     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */
563076ba34aSJunchao Zhang 
5649566063dSJacob Faibussowitsch     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
565076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
566076ba34aSJunchao Zhang       r                    = rows[i];                   /* row id in C */
567076ba34aSJunchao Zhang       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
568076ba34aSJunchao Zhang       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
569076ba34aSJunchao Zhang     }
5709566063dSJacob Faibussowitsch     PetscCall(PetscFree(currowlens));
571076ba34aSJunchao Zhang 
572076ba34aSJunchao Zhang     rrowlens--;
573076ba34aSJunchao Zhang     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
574076ba34aSJunchao Zhang     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
575076ba34aSJunchao Zhang     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */
576076ba34aSJunchao Zhang 
577076ba34aSJunchao Zhang     /* Build the reduceSF, which performs buffer to buffer send/recv */
578076ba34aSJunchao Zhang     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
5799566063dSJacob Faibussowitsch     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
580076ba34aSJunchao Zhang     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
5819566063dSJacob Faibussowitsch     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
5829566063dSJacob Faibussowitsch     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
5839566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
584076ba34aSJunchao Zhang 
585076ba34aSJunchao Zhang     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
586076ba34aSJunchao Zhang     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
587076ba34aSJunchao Zhang     PetscSFNode *iremote;
5889566063dSJacob Faibussowitsch     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
589076ba34aSJunchao Zhang     for (i = 0; i < nranks; i++) {
590076ba34aSJunchao Zhang       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
591076ba34aSJunchao Zhang       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
592076ba34aSJunchao Zhang       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
593076ba34aSJunchao Zhang       for (PetscInt k = 0; k < nz; k++) {
594076ba34aSJunchao Zhang         iremote[leafbase + k].rank  = ranks[i];
595076ba34aSJunchao Zhang         iremote[leafbase + k].index = rootbase + k;
596076ba34aSJunchao Zhang       }
597076ba34aSJunchao Zhang     }
5989566063dSJacob Faibussowitsch     PetscCall(PetscSFCreate(comm, &reduceSF));
5999566063dSJacob Faibussowitsch     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
6009566063dSJacob Faibussowitsch     PetscCall(PetscFree2(sdisp, rdisp));
601076ba34aSJunchao Zhang 
602076ba34aSJunchao Zhang     /* Reduce Aa, Ajg to abuf and jbuf */
603076ba34aSJunchao Zhang 
604076ba34aSJunchao Zhang     /* If A uses local col ids, convert them to global ones before sending */
605076ba34aSJunchao Zhang     MatColIdxKokkosView Ajg;
606076ba34aSJunchao Zhang     if (local) {
607076ba34aSJunchao Zhang       Ajg                           = MatColIdxKokkosView("j", Annz);
608076ba34aSJunchao Zhang       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
6099371c9d4SSatish Balay       Kokkos::parallel_for(
6109371c9d4SSatish Balay         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
611076ba34aSJunchao Zhang     } else {
612076ba34aSJunchao Zhang       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
613076ba34aSJunchao Zhang     }
614076ba34aSJunchao Zhang 
615076ba34aSJunchao Zhang     MatColIdxKokkosView jbuf("jbuf", Cnnz);
616076ba34aSJunchao Zhang     abuf = MatScalarKokkosView("abuf", Cnnz);
6179566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
6189566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
6199566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
6209566063dSJacob Faibussowitsch     PetscCallMPI(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
621076ba34aSJunchao Zhang 
622076ba34aSJunchao Zhang     /* Copy data from abuf, jbuf to Ca, Cj */
623076ba34aSJunchao Zhang     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
624076ba34aSJunchao Zhang     MatColIdxKokkosView Cj("j", Cnnz);
625076ba34aSJunchao Zhang     MatScalarKokkosView Ca("a", Cnnz);
626076ba34aSJunchao Zhang 
6279371c9d4SSatish Balay     Kokkos::parallel_for(
6289371c9d4SSatish Balay       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
629076ba34aSJunchao Zhang         PetscInt i   = t.league_rank();
630076ba34aSJunchao Zhang         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
631076ba34aSJunchao Zhang         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
632076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
633076ba34aSJunchao Zhang           Ca(dst + k) = abuf(src + k);
634076ba34aSJunchao Zhang           Cj(dst + k) = jbuf(src + k);
635076ba34aSJunchao Zhang         });
636076ba34aSJunchao Zhang       });
637076ba34aSJunchao Zhang 
638076ba34aSJunchao Zhang     /* Build C with Ca, Ci, Cj */
639076ba34aSJunchao Zhang     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
6409566063dSJacob Faibussowitsch     PetscCall(PetscFree2(srowlens, reqs));
64198921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
642*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
643076ba34aSJunchao Zhang }
644076ba34aSJunchao Zhang 
64511a5261eSBarry Smith /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix
646076ba34aSJunchao Zhang 
647076ba34aSJunchao Zhang   Input Parameters:
64811a5261eSBarry Smith +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
649076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
650076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
65111a5261eSBarry Smith -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the start of the first
652076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array. E.g, if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
653076ba34aSJunchao Zhang               entry is 5, then Cdstart[i] = 3.
654076ba34aSJunchao Zhang 
655076ba34aSJunchao Zhang   Output Parameters:
65611a5261eSBarry Smith +  C        - the updated `MATMPIAIJKOKKOS` matrix
65711a5261eSBarry Smith -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter
658076ba34aSJunchao Zhang 
65911a5261eSBarry Smith   Note:
66011a5261eSBarry Smith    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern
66111a5261eSBarry Smith 
66211a5261eSBarry Smith .seealso: `MATMPIAIJKOKKOS`
663076ba34aSJunchao Zhang  */
664d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
665d71ae5a4SJacob Faibussowitsch {
666076ba34aSJunchao Zhang   const MatScalarKokkosView      &Ca = csrmat.values;
667076ba34aSJunchao Zhang   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
668076ba34aSJunchao Zhang   PetscInt                        m, n, N;
669076ba34aSJunchao Zhang 
670076ba34aSJunchao Zhang   PetscFunctionBegin;
6719566063dSJacob Faibussowitsch   PetscCall(MatGetLocalSize(C, &m, &n));
6729566063dSJacob Faibussowitsch   PetscCall(MatGetSize(C, NULL, &N));
673076ba34aSJunchao Zhang 
674076ba34aSJunchao Zhang   if (reuse == MAT_REUSE_MATRIX) {
675076ba34aSJunchao Zhang     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
676076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
677076ba34aSJunchao Zhang     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
678076ba34aSJunchao Zhang     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
679076ba34aSJunchao Zhang     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();
680076ba34aSJunchao Zhang 
681076ba34aSJunchao Zhang     /* Fill 'a' of Cd and Co on device */
6829371c9d4SSatish Balay     Kokkos::parallel_for(
6839371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
684076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
685076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
686076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
687076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
688076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
689076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
690076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
691076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
692076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
693076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
694076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
695076ba34aSJunchao Zhang           } else { /* k in [cdend, clen) */
696076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
697076ba34aSJunchao Zhang           }
698076ba34aSJunchao Zhang         });
699076ba34aSJunchao Zhang       });
700076ba34aSJunchao Zhang 
701076ba34aSJunchao Zhang     akok->a_dual.modify_device();
702076ba34aSJunchao Zhang     bkok->a_dual.modify_device();
703076ba34aSJunchao Zhang   } else if (reuse == MAT_INITIAL_MATRIX) {
704076ba34aSJunchao Zhang     Mat                        Cd, Co;
705076ba34aSJunchao Zhang     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
706076ba34aSJunchao Zhang     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
707076ba34aSJunchao Zhang     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
708076ba34aSJunchao Zhang     PetscInt                   cstart, cend;
709076ba34aSJunchao Zhang 
710076ba34aSJunchao Zhang     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
711076ba34aSJunchao Zhang        left to the diag block, diag block, right to the diag block. The diag block have col ids in [cstart,cend).
712076ba34aSJunchao Zhang        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
713076ba34aSJunchao Zhang        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalentaly
714076ba34aSJunchao Zhang        stores values of cdstart and cdend-cstart (aka Cdi[]) instead.
715076ba34aSJunchao Zhang      */
716076ba34aSJunchao Zhang     Cdstart = MatRowMapKokkosView("Cdstart", m);
7179566063dSJacob Faibussowitsch     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */
718076ba34aSJunchao Zhang 
719076ba34aSJunchao Zhang     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
720076ba34aSJunchao Zhang       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
721076ba34aSJunchao Zhang      */
7229371c9d4SSatish Balay     Kokkos::parallel_for(
7239371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
724076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
725076ba34aSJunchao Zhang                                                    PetscInt i = t.league_rank(); /* row i */
726076ba34aSJunchao Zhang                                                    PetscInt j, first, count, step;
727076ba34aSJunchao Zhang 
728076ba34aSJunchao Zhang                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
729076ba34aSJunchao Zhang                                                      Cdi(0) = 0;
730076ba34aSJunchao Zhang                                                      Coi(0) = 0;
731076ba34aSJunchao Zhang                                                    }
732076ba34aSJunchao Zhang 
733076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
734076ba34aSJunchao Zhang           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
735076ba34aSJunchao Zhang         */
736076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - Ci(i);
737076ba34aSJunchao Zhang                                                    first = Ci(i);
738076ba34aSJunchao Zhang                                                    while (count > 0) {
739076ba34aSJunchao Zhang                                                      j    = first;
740076ba34aSJunchao Zhang                                                      step = count / 2;
741076ba34aSJunchao Zhang                                                      j += step;
742076ba34aSJunchao Zhang                                                      if (Cj(j) < cstart) {
743076ba34aSJunchao Zhang                                                        first = ++j;
744076ba34aSJunchao Zhang                                                        count -= step + 1;
745076ba34aSJunchao Zhang                                                      } else count = step;
746076ba34aSJunchao Zhang                                                    }
747076ba34aSJunchao Zhang                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */
748076ba34aSJunchao Zhang 
749076ba34aSJunchao Zhang                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
750076ba34aSJunchao Zhang                                                    count = Ci(i + 1) - first;
751076ba34aSJunchao Zhang                                                    while (count > 0) {
752076ba34aSJunchao Zhang                                                      j    = first;
753076ba34aSJunchao Zhang                                                      step = count / 2;
754076ba34aSJunchao Zhang                                                      j += step;
755076ba34aSJunchao Zhang                                                      if (Cj(j) < cend) {
756076ba34aSJunchao Zhang                                                        first = ++j;
757076ba34aSJunchao Zhang                                                        count -= step + 1;
758076ba34aSJunchao Zhang                                                      } else count = step;
759076ba34aSJunchao Zhang                                                    }
760076ba34aSJunchao Zhang                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
761076ba34aSJunchao Zhang                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
762076ba34aSJunchao Zhang         });
763076ba34aSJunchao Zhang       });
764076ba34aSJunchao Zhang 
765076ba34aSJunchao Zhang     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
7669371c9d4SSatish Balay     Kokkos::parallel_scan(
7679371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
768076ba34aSJunchao Zhang         update += Cdi(i);
769076ba34aSJunchao Zhang         if (final) Cdi(i) = update;
770076ba34aSJunchao Zhang       });
7719371c9d4SSatish Balay     Kokkos::parallel_scan(
7729371c9d4SSatish Balay       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
773076ba34aSJunchao Zhang         update += Coi(i);
774076ba34aSJunchao Zhang         if (final) Coi(i) = update;
775076ba34aSJunchao Zhang       });
776076ba34aSJunchao Zhang 
777076ba34aSJunchao Zhang     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
778076ba34aSJunchao Zhang        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
779076ba34aSJunchao Zhang     */
780076ba34aSJunchao Zhang     Cdi_dual.modify_device();
781076ba34aSJunchao Zhang     Coi_dual.modify_device();
782076ba34aSJunchao Zhang     Cdi_dual.sync_host();
783076ba34aSJunchao Zhang     Coi_dual.sync_host();
784076ba34aSJunchao Zhang     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
785076ba34aSJunchao Zhang     PetscInt Co_nnz = Coi_dual.view_host().data()[m];
786076ba34aSJunchao Zhang 
787076ba34aSJunchao Zhang     /* With nnz, allocate a, j for Cd and Co */
788076ba34aSJunchao Zhang     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
789076ba34aSJunchao Zhang     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);
790076ba34aSJunchao Zhang 
791076ba34aSJunchao Zhang     /* Fill a, j of Cd and Co on device */
792076ba34aSJunchao Zhang     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
793076ba34aSJunchao Zhang     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();
794076ba34aSJunchao Zhang 
7959371c9d4SSatish Balay     Kokkos::parallel_for(
7969371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
797076ba34aSJunchao Zhang         PetscInt i       = t.league_rank();     /* row i */
798076ba34aSJunchao Zhang         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
799076ba34aSJunchao Zhang         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
800076ba34aSJunchao Zhang         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
801076ba34aSJunchao Zhang         PetscInt cdend   = cdstart + cdlen;
802076ba34aSJunchao Zhang         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
803076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
804076ba34aSJunchao Zhang           if (k < cdstart) { /* k in [0, cdstart) */
805076ba34aSJunchao Zhang             Coa(Coi(i) + k) = Ca(Ci(i) + k);
806076ba34aSJunchao Zhang             Coj(Coi(i) + k) = Cj(Ci(i) + k);
807076ba34aSJunchao Zhang           } else if (k < cdend) { /* k in [cdstart, cdend) */
808076ba34aSJunchao Zhang             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
809076ba34aSJunchao Zhang             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
810076ba34aSJunchao Zhang           } else {                                                /* k in [cdend, clen) */
811076ba34aSJunchao Zhang             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
812076ba34aSJunchao Zhang             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
813076ba34aSJunchao Zhang           }
814076ba34aSJunchao Zhang         });
815076ba34aSJunchao Zhang       });
816076ba34aSJunchao Zhang 
817076ba34aSJunchao Zhang     Cdj_dual.modify_device();
818076ba34aSJunchao Zhang     Cda_dual.modify_device();
819076ba34aSJunchao Zhang     Coj_dual.modify_device();
820076ba34aSJunchao Zhang     Coa_dual.modify_device();
821076ba34aSJunchao Zhang     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in each's MatAssemblyEnd */
822076ba34aSJunchao Zhang     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
823076ba34aSJunchao Zhang     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
8249566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
8259566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
8269566063dSJacob Faibussowitsch     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
827076ba34aSJunchao Zhang   }
828*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
829076ba34aSJunchao Zhang }
830076ba34aSJunchao Zhang 
831076ba34aSJunchao Zhang /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKS matrix's global col ids.
832076ba34aSJunchao Zhang 
833076ba34aSJunchao Zhang   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in Kokkos view.
834076ba34aSJunchao Zhang 
835076ba34aSJunchao Zhang   Input Parameters:
836076ba34aSJunchao Zhang +  C        - the MATMPIAIJKOKKOS matrix, of size m,n,M,N
837076ba34aSJunchao Zhang .  reuse    - indicate whether the matrix has called this function before
838076ba34aSJunchao Zhang .  csrmat   - the KokkosCsrMatrix, of size m,N
839076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_REUSE_MATRIX, it is an input parameter. For each row in csrmat, it stores the offset of the first
840076ba34aSJunchao Zhang               entry of the diag block of C in csrmat's j array.
841076ba34aSJunchao Zhang 
842076ba34aSJunchao Zhang   Output Parameters:
843076ba34aSJunchao Zhang +  C        - the updated MATMPIAIJKOKKOS matrix
844076ba34aSJunchao Zhang -  Cdoffset - when reuse == MAT_INITIAL_MATRIX, it is an output parameter
845076ba34aSJunchao Zhang 
84611a5261eSBarry Smith   Note:
84711a5261eSBarry Smith   the input matrix's col ids and col size will be changed.
848076ba34aSJunchao Zhang */
849d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
850d71ae5a4SJacob Faibussowitsch {
851076ba34aSJunchao Zhang   Mat_SeqAIJKokkos      *ckok;
852076ba34aSJunchao Zhang   ISLocalToGlobalMapping l2gmap;
853076ba34aSJunchao Zhang   const PetscInt        *garray;
854076ba34aSJunchao Zhang   PetscInt               sz;
855076ba34aSJunchao Zhang 
856076ba34aSJunchao Zhang   PetscFunctionBegin;
857076ba34aSJunchao Zhang   /* Compact P_other's global col ids and col size. We do it since we guess with local ids KK might be more memory scalable */
8589566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
859076ba34aSJunchao Zhang   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
860076ba34aSJunchao Zhang   ckok->j_dual.modify_host(); /* P_other's j is modified on host; we need to sync it on device */
861076ba34aSJunchao Zhang   ckok->j_dual.sync_device();
862076ba34aSJunchao Zhang   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */
863076ba34aSJunchao Zhang 
864076ba34aSJunchao Zhang   /* Build l2g -- the local to global mapping of C's cols */
8659566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
8669566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
86708401ef6SPierre Jolivet   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);
868076ba34aSJunchao Zhang 
869076ba34aSJunchao Zhang   ConstMatColIdxKokkosViewHost tmp(garray, sz);
870076ba34aSJunchao Zhang   l2g = MatColIdxKokkosView("l2g", sz);
871076ba34aSJunchao Zhang   Kokkos::deep_copy(l2g, tmp);
872076ba34aSJunchao Zhang 
8739566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
8749566063dSJacob Faibussowitsch   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
875*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
876076ba34aSJunchao Zhang }
877076ba34aSJunchao Zhang 
878076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
879076ba34aSJunchao Zhang 
880076ba34aSJunchao Zhang   Input Parameters:
881076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
882076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
883076ba34aSJunchao Zhang .  B        - an MPIAIJKOKKOS matrix
884076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
885076ba34aSJunchao Zhang 
88611a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
887076ba34aSJunchao Zhang */
888d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
889d71ae5a4SJacob Faibussowitsch {
890076ba34aSJunchao Zhang   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
891076ba34aSJunchao Zhang   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
892076ba34aSJunchao Zhang   IS                       glob = NULL;
893076ba34aSJunchao Zhang   const PetscInt          *garray;
894076ba34aSJunchao Zhang   PetscInt                 N = B->cmap->N, sz;
895076ba34aSJunchao Zhang   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
896076ba34aSJunchao Zhang   MatColIdxKokkosView      l2g2;
897076ba34aSJunchao Zhang   Mat                      C1, C2; /* intermediate matrices */
898076ba34aSJunchao Zhang 
899076ba34aSJunchao Zhang   PetscFunctionBegin;
900076ba34aSJunchao Zhang   /* C1 = Ad * B_local. B_local is a matrix got by merging Bd and Bo, and uses local col ids */
9019566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
9029566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
9039566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
9049566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
905076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
9069566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
907dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
908076ba34aSJunchao Zhang 
9099566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(glob, &garray));
9109566063dSJacob Faibussowitsch   PetscCall(ISGetSize(glob, &sz));
911076ba34aSJunchao Zhang   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
912076ba34aSJunchao Zhang   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* maybe just an alias to tmp, so we restore garray at the very end */
9139566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));
914076ba34aSJunchao Zhang 
915076ba34aSJunchao Zhang   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
9169566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));
917076ba34aSJunchao Zhang 
918076ba34aSJunchao Zhang   /* Compact B_other to use local ids as we guess KK spgemm is more memroy scalable with that; We could skip the compaction to simplify code */
9199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
9209566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
9219566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
9229566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
923076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
9249566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
925dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
9269566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));
927076ba34aSJunchao Zhang 
928076ba34aSJunchao Zhang   /* C = C1 + C2.  We actually use their global col ids versions in adding */
929076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
930076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
931076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
932076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
933076ba34aSJunchao Zhang 
934076ba34aSJunchao Zhang   mm->C1 = C1;
935076ba34aSJunchao Zhang   mm->C2 = C2;
9369566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(glob, &garray));
9379566063dSJacob Faibussowitsch   PetscCall(ISDestroy(&glob));
938*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
939076ba34aSJunchao Zhang }
940076ba34aSJunchao Zhang 
941076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos
942076ba34aSJunchao Zhang 
943076ba34aSJunchao Zhang   Input Parameters:
944076ba34aSJunchao Zhang +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
945076ba34aSJunchao Zhang .  A        - an MPIAIJKOKKOS matrix
946076ba34aSJunchao Zhang .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
947076ba34aSJunchao Zhang .  localB   - Does B use local col ids? If false, then B is already in global col ids.
948076ba34aSJunchao Zhang .  N        - col size of the "parallel B matrix". It implies B's global col ids are in range of [0,N) and N is the same across the communicator.
949076ba34aSJunchao Zhang .  l2g      - If localB, then l2g maps B's local col ids to global ones.
950076ba34aSJunchao Zhang -  mm       - a struct used to stash intermediate data in AtB
951076ba34aSJunchao Zhang 
95211a5261eSBarry Smith   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
953076ba34aSJunchao Zhang */
954d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
955d71ae5a4SJacob Faibussowitsch {
956076ba34aSJunchao Zhang   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
957076ba34aSJunchao Zhang   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
958076ba34aSJunchao Zhang   Mat         C1, C2;               /* intermediate matrices */
959076ba34aSJunchao Zhang 
960076ba34aSJunchao Zhang   PetscFunctionBegin;
961076ba34aSJunchao Zhang   /* C1 = Ad^t * B */
9629566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
9639566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
9649566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C1, product->fill));
965076ba34aSJunchao Zhang   C1->product->api_user = product->api_user;
9669566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C1));
967dbbe0bcdSBarry Smith   PetscUseTypeMethod(C1, productsymbolic);
968076ba34aSJunchao Zhang 
9699566063dSJacob Faibussowitsch   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
970076ba34aSJunchao Zhang   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */
971076ba34aSJunchao Zhang 
972076ba34aSJunchao Zhang   /* C2 = Ao^t * B */
9739566063dSJacob Faibussowitsch   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
9749566063dSJacob Faibussowitsch   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
9759566063dSJacob Faibussowitsch   PetscCall(MatProductSetFill(C2, product->fill));
976076ba34aSJunchao Zhang   C2->product->api_user = product->api_user;
9779566063dSJacob Faibussowitsch   PetscCall(MatProductSetFromOptions(C2));
978dbbe0bcdSBarry Smith   PetscUseTypeMethod(C2, productsymbolic);
979076ba34aSJunchao Zhang 
9809566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));
981076ba34aSJunchao Zhang 
982076ba34aSJunchao Zhang   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
983076ba34aSJunchao Zhang   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
984076ba34aSJunchao Zhang   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
985076ba34aSJunchao Zhang   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
986076ba34aSJunchao Zhang   mm->C1 = C1;
987076ba34aSJunchao Zhang   mm->C2 = C2;
988*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
989076ba34aSJunchao Zhang }
990076ba34aSJunchao Zhang 
991d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
992d71ae5a4SJacob Faibussowitsch {
993076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
994076ba34aSJunchao Zhang   MatProductType               ptype;
995076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
996076ba34aSJunchao Zhang   MatMatStruct                *mm = NULL;
997076ba34aSJunchao Zhang   MatMatStruct_AB             *ab;
998076ba34aSJunchao Zhang   MatMatStruct_AtB            *atb;
999076ba34aSJunchao Zhang   Mat                          A, B, Ad, Ao, Bd, Bo;
1000076ba34aSJunchao Zhang   const MatScalarType          one = 1.0; /* Not use literal 1.0 directly, to avoid wrong template instantiation in KokkosSparse::spadd_numeric */
1001076ba34aSJunchao Zhang 
1002076ba34aSJunchao Zhang   PetscFunctionBegin;
1003076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
1004076ba34aSJunchao Zhang   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1005076ba34aSJunchao Zhang   ptype  = product->type;
1006076ba34aSJunchao Zhang   A      = product->A;
1007076ba34aSJunchao Zhang   B      = product->B;
1008076ba34aSJunchao Zhang   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1009076ba34aSJunchao Zhang   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1010076ba34aSJunchao Zhang   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1011076ba34aSJunchao Zhang   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;
1012076ba34aSJunchao Zhang 
1013076ba34aSJunchao Zhang   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1014076ba34aSJunchao Zhang     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1015076ba34aSJunchao Zhang     ab               = mmdata->mmAB;
1016076ba34aSJunchao Zhang     atb              = mmdata->mmAtB;
1017076ba34aSJunchao Zhang     if (ab) {
1018076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1019076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1020076ba34aSJunchao Zhang     }
1021076ba34aSJunchao Zhang     if (atb) {
1022076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1023076ba34aSJunchao Zhang       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1024076ba34aSJunchao Zhang     }
1025*3ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
1026076ba34aSJunchao Zhang   }
1027076ba34aSJunchao Zhang 
1028076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1029076ba34aSJunchao Zhang     ab = mmdata->mmAB;
1030076ba34aSJunchao Zhang     /* C1 = Ad * B_local */
103108401ef6SPierre Jolivet     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
10329566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
10335f80ce2aSJacob Faibussowitsch     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
10349566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
10359566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
10369371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1037076ba34aSJunchao Zhang     /* C2 = Ao * B_other */
103808401ef6SPierre Jolivet     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
10399566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
10409566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1041076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1042076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1043076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(ab);
1044076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1045076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
104608401ef6SPierre Jolivet     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1047076ba34aSJunchao Zhang     /* C1 = Ad^t * B_local */
10489566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
104908401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
10509566063dSJacob Faibussowitsch     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
10519566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1052076ba34aSJunchao Zhang 
1053076ba34aSJunchao Zhang     /* C2 = Ao^t * B_local */
105408401ef6SPierre Jolivet     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
10559566063dSJacob Faibussowitsch     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
10569566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1057076ba34aSJunchao Zhang     /* Form C2_global */
10589371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1059076ba34aSJunchao Zhang     /* C = C1_global + C2_global */
1060076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1061076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1062076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1063076ba34aSJunchao Zhang     ab = mmdata->mmAB;
10649566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1065076ba34aSJunchao Zhang 
1066076ba34aSJunchao Zhang     /* ab->C1 = Ad * B_local */
106708401ef6SPierre Jolivet     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
10689566063dSJacob Faibussowitsch     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
10699566063dSJacob Faibussowitsch     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
10709371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1071076ba34aSJunchao Zhang     /* ab->C2 = Ao * B_other */
10729566063dSJacob Faibussowitsch     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
10739566063dSJacob Faibussowitsch     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1074076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1075076ba34aSJunchao Zhang 
1076076ba34aSJunchao Zhang     /* atb->C1 = Bd^t * ab->C_petsc */
1077076ba34aSJunchao Zhang     atb = mmdata->mmAtB;
107808401ef6SPierre Jolivet     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
10799566063dSJacob Faibussowitsch     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
10809566063dSJacob Faibussowitsch     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1081076ba34aSJunchao Zhang     /* atb->C2 = Bo^t * ab->C_petsc */
10829566063dSJacob Faibussowitsch     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
10839566063dSJacob Faibussowitsch     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
10849371c9d4SSatish Balay     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1085076ba34aSJunchao Zhang     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1086076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1087076ba34aSJunchao Zhang   }
1088076ba34aSJunchao Zhang   /* Split C_global to form C */
10899566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1090*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1091076ba34aSJunchao Zhang }
1092076ba34aSJunchao Zhang 
1093d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1094d71ae5a4SJacob Faibussowitsch {
1095076ba34aSJunchao Zhang   Mat                          A, B;
1096076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
1097076ba34aSJunchao Zhang   MatProductType               ptype;
1098076ba34aSJunchao Zhang   MatProductData_MPIAIJKokkos *mmdata;
1099076ba34aSJunchao Zhang   MatMatStruct                *mm   = NULL;
1100076ba34aSJunchao Zhang   IS                           glob = NULL;
1101076ba34aSJunchao Zhang   const PetscInt              *garray;
1102076ba34aSJunchao Zhang   PetscInt                     m, n, M, N, sz;
1103076ba34aSJunchao Zhang   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */
1104076ba34aSJunchao Zhang 
1105076ba34aSJunchao Zhang   PetscFunctionBegin;
1106076ba34aSJunchao Zhang   MatCheckProduct(C, 1);
110728b400f6SJacob Faibussowitsch   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1108076ba34aSJunchao Zhang   ptype = product->type;
1109076ba34aSJunchao Zhang   A     = product->A;
1110076ba34aSJunchao Zhang   B     = product->B;
1111076ba34aSJunchao Zhang 
1112076ba34aSJunchao Zhang   switch (ptype) {
11139371c9d4SSatish Balay   case MATPRODUCT_AB:
11149371c9d4SSatish Balay     m = A->rmap->n;
11159371c9d4SSatish Balay     n = B->cmap->n;
11169371c9d4SSatish Balay     M = A->rmap->N;
11179371c9d4SSatish Balay     N = B->cmap->N;
11189371c9d4SSatish Balay     break;
11199371c9d4SSatish Balay   case MATPRODUCT_AtB:
11209371c9d4SSatish Balay     m = A->cmap->n;
11219371c9d4SSatish Balay     n = B->cmap->n;
11229371c9d4SSatish Balay     M = A->cmap->N;
11239371c9d4SSatish Balay     N = B->cmap->N;
11249371c9d4SSatish Balay     break;
11259371c9d4SSatish Balay   case MATPRODUCT_PtAP:
11269371c9d4SSatish Balay     m = B->cmap->n;
11279371c9d4SSatish Balay     n = B->cmap->n;
11289371c9d4SSatish Balay     M = B->cmap->N;
11299371c9d4SSatish Balay     N = B->cmap->N;
11309371c9d4SSatish Balay     break; /* BtAB */
1131d71ae5a4SJacob Faibussowitsch   default:
1132d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1133076ba34aSJunchao Zhang   }
1134076ba34aSJunchao Zhang 
11359566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, m, n, M, N));
11369566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->rmap));
11379566063dSJacob Faibussowitsch   PetscCall(PetscLayoutSetUp(C->cmap));
11389566063dSJacob Faibussowitsch   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1139076ba34aSJunchao Zhang 
1140076ba34aSJunchao Zhang   mmdata           = new MatProductData_MPIAIJKokkos();
1141076ba34aSJunchao Zhang   mmdata->reusesym = product->api_user;
1142076ba34aSJunchao Zhang 
1143076ba34aSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
1144076ba34aSJunchao Zhang     mmdata->mmAB = new MatMatStruct_AB();
11459566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1146076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1147076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
1148076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB();
1149076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
11509566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
11519566063dSJacob Faibussowitsch     PetscCall(ISGetIndices(glob, &garray));
11529566063dSJacob Faibussowitsch     PetscCall(ISGetSize(glob, &sz));
1153076ba34aSJunchao Zhang     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
11549566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
11559566063dSJacob Faibussowitsch     PetscCall(ISRestoreIndices(glob, &garray));
11569566063dSJacob Faibussowitsch     PetscCall(ISDestroy(&glob));
1157076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1158076ba34aSJunchao Zhang   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1159076ba34aSJunchao Zhang     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1160076ba34aSJunchao Zhang     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1161076ba34aSJunchao Zhang     auto ab       = mmdata->mmAB;
1162076ba34aSJunchao Zhang     auto atb      = mmdata->mmAtB;
11639566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1164076ba34aSJunchao Zhang     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
11659566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
11669566063dSJacob Faibussowitsch     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1167076ba34aSJunchao Zhang     mm = static_cast<MatMatStruct *>(atb);
1168076ba34aSJunchao Zhang   }
1169076ba34aSJunchao Zhang   /* Split the C_global into petsc A, B format */
11709566063dSJacob Faibussowitsch   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1171076ba34aSJunchao Zhang   C->product->data       = mmdata;
1172076ba34aSJunchao Zhang   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1173076ba34aSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1174*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1175076ba34aSJunchao Zhang }
1176076ba34aSJunchao Zhang 
1177d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1178d71ae5a4SJacob Faibussowitsch {
1179076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
1180076ba34aSJunchao Zhang   PetscBool    match   = PETSC_FALSE;
1181076ba34aSJunchao Zhang   PetscBool    usecpu  = PETSC_FALSE;
1182076ba34aSJunchao Zhang 
1183076ba34aSJunchao Zhang   PetscFunctionBegin;
1184076ba34aSJunchao Zhang   MatCheckProduct(mat, 1);
118548a46eb9SPierre Jolivet   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1186076ba34aSJunchao Zhang   if (match) { /* we can always fallback to the CPU if requested */
1187076ba34aSJunchao Zhang     switch (product->type) {
1188076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1189076ba34aSJunchao Zhang       if (product->api_user) {
1190d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
11919566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1192d0609cedSBarry Smith         PetscOptionsEnd();
1193076ba34aSJunchao Zhang       } else {
1194d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
11959566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1196d0609cedSBarry Smith         PetscOptionsEnd();
1197076ba34aSJunchao Zhang       }
1198076ba34aSJunchao Zhang       break;
1199076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1200076ba34aSJunchao Zhang       if (product->api_user) {
1201d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
12029566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1203d0609cedSBarry Smith         PetscOptionsEnd();
1204076ba34aSJunchao Zhang       } else {
1205d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
12069566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1207d0609cedSBarry Smith         PetscOptionsEnd();
1208076ba34aSJunchao Zhang       }
1209076ba34aSJunchao Zhang       break;
1210076ba34aSJunchao Zhang     case MATPRODUCT_PtAP:
1211076ba34aSJunchao Zhang       if (product->api_user) {
1212d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
12139566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1214d0609cedSBarry Smith         PetscOptionsEnd();
1215076ba34aSJunchao Zhang       } else {
1216d0609cedSBarry Smith         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
12179566063dSJacob Faibussowitsch         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1218d0609cedSBarry Smith         PetscOptionsEnd();
1219076ba34aSJunchao Zhang       }
1220076ba34aSJunchao Zhang       break;
1221d71ae5a4SJacob Faibussowitsch     default:
1222d71ae5a4SJacob Faibussowitsch       break;
1223076ba34aSJunchao Zhang     }
1224076ba34aSJunchao Zhang     match = (PetscBool)!usecpu;
1225076ba34aSJunchao Zhang   }
1226076ba34aSJunchao Zhang   if (match) {
1227076ba34aSJunchao Zhang     switch (product->type) {
1228076ba34aSJunchao Zhang     case MATPRODUCT_AB:
1229076ba34aSJunchao Zhang     case MATPRODUCT_AtB:
1230d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_PtAP:
1231d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1232d71ae5a4SJacob Faibussowitsch       break;
1233d71ae5a4SJacob Faibussowitsch     default:
1234d71ae5a4SJacob Faibussowitsch       break;
1235076ba34aSJunchao Zhang     }
1236076ba34aSJunchao Zhang   }
1237076ba34aSJunchao Zhang   /* fallback to MPIAIJ ops */
123848a46eb9SPierre Jolivet   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1239*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1240076ba34aSJunchao Zhang }
1241076ba34aSJunchao Zhang 
1242d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1243d71ae5a4SJacob Faibussowitsch {
1244394ed5ebSJunchao Zhang   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1245cbc6b225SStefano Zampini   Mat_MPIAIJKokkos *mpikok;
124642550becSJunchao Zhang 
124742550becSJunchao Zhang   PetscFunctionBegin;
124830203840SJunchao Zhang   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1249cbc6b225SStefano Zampini   mat->preallocated = PETSC_TRUE;
12509566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
12519566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
12529566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(mat));
1253cbc6b225SStefano Zampini   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1254cbc6b225SStefano Zampini   delete mpikok;
1255394ed5ebSJunchao Zhang   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1256*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
125742550becSJunchao Zhang }
125842550becSJunchao Zhang 
1259d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1260d71ae5a4SJacob Faibussowitsch {
1261394ed5ebSJunchao Zhang   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
126242550becSJunchao Zhang   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
126342550becSJunchao Zhang   Mat                         A = mpiaij->A, B = mpiaij->B;
1264158ec288SJunchao Zhang   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
126542550becSJunchao Zhang   MatScalarKokkosView         Aa, Ba;
1266394ed5ebSJunchao Zhang   MatScalarKokkosView         v1;
126742550becSJunchao Zhang   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
126842550becSJunchao Zhang   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1269158ec288SJunchao Zhang   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1270158ec288SJunchao Zhang   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1271394ed5ebSJunchao Zhang   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1272394ed5ebSJunchao Zhang   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
127342550becSJunchao Zhang   PetscMemType                memtype;
127442550becSJunchao Zhang 
127542550becSJunchao Zhang   PetscFunctionBegin;
12769566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
127742550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) {         /* If user gave v[] in host, we need to copy it to device if any */
1278394ed5ebSJunchao Zhang     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
127942550becSJunchao Zhang   } else {
1280394ed5ebSJunchao Zhang     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
128142550becSJunchao Zhang   }
128242550becSJunchao Zhang 
128342550becSJunchao Zhang   if (imode == INSERT_VALUES) {
12849566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
12859566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1286394ed5ebSJunchao Zhang   } else {
12879566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
12889566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
128942550becSJunchao Zhang   }
129042550becSJunchao Zhang 
129142550becSJunchao Zhang   /* Pack entries to be sent to remote */
12929371c9d4SSatish Balay   Kokkos::parallel_for(
12939371c9d4SSatish Balay     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
129442550becSJunchao Zhang 
129542550becSJunchao Zhang   /* Send remote entries to their owner and overlap the communication with local computation */
12969566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1297158ec288SJunchao Zhang   /* Add local entries to A and B in one kernel */
12989371c9d4SSatish Balay   Kokkos::parallel_for(
12999371c9d4SSatish Balay     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1300158ec288SJunchao Zhang       PetscScalar sum = 0.0;
1301158ec288SJunchao Zhang       if (i < Annz) {
1302158ec288SJunchao Zhang         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1303ac38520cSJunchao Zhang         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1304158ec288SJunchao Zhang       } else {
1305158ec288SJunchao Zhang         i -= Annz;
1306158ec288SJunchao Zhang         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1307ac38520cSJunchao Zhang         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1308158ec288SJunchao Zhang       }
1309158ec288SJunchao Zhang     });
13109566063dSJacob Faibussowitsch   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
131142550becSJunchao Zhang 
1312158ec288SJunchao Zhang   /* Add received remote entries to A and B in one kernel */
13139371c9d4SSatish Balay   Kokkos::parallel_for(
13149371c9d4SSatish Balay     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1315158ec288SJunchao Zhang       if (i < Annz2) {
1316158ec288SJunchao Zhang         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1317158ec288SJunchao Zhang       } else {
1318158ec288SJunchao Zhang         i -= Annz2;
1319158ec288SJunchao Zhang         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1320158ec288SJunchao Zhang       }
1321158ec288SJunchao Zhang     });
132242550becSJunchao Zhang 
1323394ed5ebSJunchao Zhang   if (imode == INSERT_VALUES) {
13249566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
13259566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1326394ed5ebSJunchao Zhang   } else {
13279566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
13289566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1329394ed5ebSJunchao Zhang   }
1330*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
133142550becSJunchao Zhang }
133242550becSJunchao Zhang 
1333d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1334d71ae5a4SJacob Faibussowitsch {
133542550becSJunchao Zhang   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
1336076ba34aSJunchao Zhang 
1337076ba34aSJunchao Zhang   PetscFunctionBegin;
13389566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
13399566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
13409566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
13419566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
134242550becSJunchao Zhang   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
13439566063dSJacob Faibussowitsch   PetscCall(MatDestroy_MPIAIJ(A));
1344*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1345076ba34aSJunchao Zhang }
1346076ba34aSJunchao Zhang 
1347d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1348d71ae5a4SJacob Faibussowitsch {
13498c3ff71bSJunchao Zhang   Mat         B;
1350076ba34aSJunchao Zhang   Mat_MPIAIJ *a;
13518c3ff71bSJunchao Zhang 
13528c3ff71bSJunchao Zhang   PetscFunctionBegin;
13538c3ff71bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
13549566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
13558c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
13569566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
13578c3ff71bSJunchao Zhang   }
13588c3ff71bSJunchao Zhang   B = *newmat;
13598c3ff71bSJunchao Zhang 
13606f3d89d0SStefano Zampini   B->boundtocpu = PETSC_FALSE;
13619566063dSJacob Faibussowitsch   PetscCall(PetscFree(B->defaultvectype));
13629566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
13639566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
13648c3ff71bSJunchao Zhang 
1365076ba34aSJunchao Zhang   a = static_cast<Mat_MPIAIJ *>(A->data);
13669566063dSJacob Faibussowitsch   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
13679566063dSJacob Faibussowitsch   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
13689566063dSJacob Faibussowitsch   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1369076ba34aSJunchao Zhang 
13708c3ff71bSJunchao Zhang   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
13718c3ff71bSJunchao Zhang   B->ops->mult                  = MatMult_MPIAIJKokkos;
13728c3ff71bSJunchao Zhang   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
13738c3ff71bSJunchao Zhang   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1374076ba34aSJunchao Zhang   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1375076ba34aSJunchao Zhang   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
13768c3ff71bSJunchao Zhang 
13779566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
13789566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
13799566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
13809566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1381*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
13828c3ff71bSJunchao Zhang }
/*MC
   MATAIJKOKKOS - "aijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos

   A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types

   Options Database Keys:
.  -mat_type aijkokkos - sets the matrix type to "aijkokkos" during a call to MatSetFromOptions()

  Level: beginner

.seealso: `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
M*/
1395d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1396d71ae5a4SJacob Faibussowitsch {
13978c3ff71bSJunchao Zhang   PetscFunctionBegin;
13989566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
13999566063dSJacob Faibussowitsch   PetscCall(MatCreate_MPIAIJ(A));
14009566063dSJacob Faibussowitsch   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1401*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
14028c3ff71bSJunchao Zhang }
14038c3ff71bSJunchao Zhang 
/*@C
   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format).  This matrix will ultimately be pushed down
   to Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameters d_nz and o_nz (or the arrays d_nnz and o_nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.

   Collective

   Input Parameters:
+  comm - MPI communicator
.  m - number of local rows
.  n - number of local columns
.  M - number of global rows
.  N - number of global columns
.  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion
           (possibly different for each row) or NULL
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion
           (possibly different for each row) or NULL

   Output Parameter:
.  A - the matrix

   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
   MatXXXXSetPreallocation() paradigm instead of this routine directly.
   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]

   Notes:
   If d_nnz (o_nnz) is given then d_nz (o_nz) is ignored

   When comm has a single process the matrix is created as `MATSEQAIJKOKKOS` and only d_nz and d_nnz are used.

   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran 77
   storage.  That is, the stored row and column indices can begin at
   either one (as in Fortran) or zero.  See the users' manual for details.

   Specify the preallocated storage with either d_nz or d_nnz (not both), and likewise o_nz or o_nnz.
   Set d_nz = o_nz = `PETSC_DEFAULT` and d_nnz = o_nnz = NULL for PETSc to control dynamic memory
   allocation.  For large problems you MUST preallocate memory or you
   will get TERRIBLE performance, see the users' manual chapter on matrices.

   By default, this format uses inodes (identical nodes) when possible, to
   improve numerical efficiency of matrix-vector products and solves. We
   search for consecutive rows with the same nonzero structure, thereby
   reusing matrix information to achieve increased efficiency.

   Level: intermediate

.seealso: `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
@*/
1450d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1451d71ae5a4SJacob Faibussowitsch {
14528c3ff71bSJunchao Zhang   PetscMPIInt size;
14538c3ff71bSJunchao Zhang 
14548c3ff71bSJunchao Zhang   PetscFunctionBegin;
14559566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14569566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, M, N));
14579566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
14588c3ff71bSJunchao Zhang   if (size > 1) {
14599566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
14609566063dSJacob Faibussowitsch     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
14618c3ff71bSJunchao Zhang   } else {
14629566063dSJacob Faibussowitsch     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
14639566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
14648c3ff71bSJunchao Zhang   }
1465*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
14668c3ff71bSJunchao Zhang }
14678c3ff71bSJunchao Zhang 
1468a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat.
1469d71ae5a4SJacob Faibussowitsch PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1470d71ae5a4SJacob Faibussowitsch {
1471a587d139SMark   PetscMPIInt                size, rank;
1472a587d139SMark   MPI_Comm                   comm;
1473042217e8SBarry Smith   PetscSplitCSRDataStructure d_mat = NULL;
1474a587d139SMark 
1475a587d139SMark   PetscFunctionBegin;
14769566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
14779566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
14789566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1479a587d139SMark   if (size == 1) {
14809566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
14819566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1482a587d139SMark   } else {
1483a587d139SMark     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
14849566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
14859566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
14869566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
14872c71b3e2SJacob Faibussowitsch     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1488a587d139SMark   }
1489a587d139SMark   // act like MatSetValues because not called on host
1490a587d139SMark   if (A->assembled) {
149148a46eb9SPierre Jolivet     if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n"));
1492a587d139SMark     A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here?
1493a587d139SMark   } else {
14949566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled));
1495a587d139SMark   }
1496a587d139SMark   if (!d_mat) {
1497042217e8SBarry Smith     struct _n_SplitCSRMat h_mat; /* host container */
1498a587d139SMark     Mat_SeqAIJKokkos     *aijkokA;
1499a587d139SMark     Mat_SeqAIJ           *jaca;
1500a587d139SMark     PetscInt              n = A->rmap->n, nnz;
1501a587d139SMark     Mat                   Amat;
1502042217e8SBarry Smith     PetscInt             *colmap;
1503042217e8SBarry Smith 
1504042217e8SBarry Smith     /* create and copy h_mat */
150549b994a9SMark Adams     h_mat.M = A->cmap->N; // use for debug build
15069566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1507a587d139SMark     if (size == 1) {
1508a587d139SMark       Amat            = A;
1509a587d139SMark       jaca            = (Mat_SeqAIJ *)A->data;
15109371c9d4SSatish Balay       h_mat.rstart    = 0;
15119371c9d4SSatish Balay       h_mat.rend      = A->rmap->n;
15129371c9d4SSatish Balay       h_mat.cstart    = 0;
15139371c9d4SSatish Balay       h_mat.cend      = A->cmap->n;
1514a587d139SMark       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1515a587d139SMark       h_mat.offdiag.a                   = NULL;
1516a587d139SMark       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1517a587d139SMark     } else {
1518a587d139SMark       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1519a587d139SMark       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1520a587d139SMark       PetscInt          ii;
1521a587d139SMark       Mat_SeqAIJKokkos *aijkokB;
1522042217e8SBarry Smith 
1523a587d139SMark       Amat    = aij->A;
1524a587d139SMark       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1525a587d139SMark       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1526a587d139SMark       jaca    = (Mat_SeqAIJ *)aij->A->data;
152708401ef6SPierre Jolivet       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
152808401ef6SPierre Jolivet       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1529a587d139SMark       aij->donotstash          = PETSC_TRUE;
1530a587d139SMark       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1531a5b23f4aSJose E. Roman       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
15329566063dSJacob Faibussowitsch       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1533042217e8SBarry Smith       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1534a587d139SMark       // allocate B copy data
15359371c9d4SSatish Balay       h_mat.rstart = A->rmap->rstart;
15369371c9d4SSatish Balay       h_mat.rend   = A->rmap->rend;
15379371c9d4SSatish Balay       h_mat.cstart = A->cmap->rstart;
15389371c9d4SSatish Balay       h_mat.cend   = A->cmap->rend;
1539a587d139SMark       nnz          = jacb->i[n];
1540a587d139SMark       if (jacb->compressedrow.use) {
1541a587d139SMark         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1542300d22a6SJunchao Zhang         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1543300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1544300d22a6SJunchao Zhang         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1545a587d139SMark       } else {
154699551766SMark Adams         h_mat.offdiag.i = aijkokB->i_device_data();
1547a587d139SMark       }
154899551766SMark Adams       h_mat.offdiag.j = aijkokB->j_device_data();
1549076ba34aSJunchao Zhang       h_mat.offdiag.a = aijkokB->a_device_data();
1550a587d139SMark       {
1551042217e8SBarry Smith         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1552300d22a6SJunchao Zhang         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1553300d22a6SJunchao Zhang         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1554300d22a6SJunchao Zhang         h_mat.colmap = aijkokB->colmap_d.data();
15559566063dSJacob Faibussowitsch         PetscCall(PetscFree(colmap));
1556a587d139SMark       }
1557a587d139SMark       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1558a587d139SMark       h_mat.offdiag.n                 = n;
1559a587d139SMark     }
1560a587d139SMark     // allocate A copy data
1561a587d139SMark     nnz                          = jaca->i[n];
1562a587d139SMark     h_mat.diag.n                 = n;
1563a587d139SMark     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
15649566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1565d5b43468SJose E. Roman     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
156699551766SMark Adams     h_mat.diag.i = aijkokA->i_device_data();
156799551766SMark Adams     h_mat.diag.j = aijkokA->j_device_data();
1568076ba34aSJunchao Zhang     h_mat.diag.a = aijkokA->a_device_data();
1569a587d139SMark     // copy pointers and metdata to device
15709566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
15719566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
15729566063dSJacob Faibussowitsch     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1573a587d139SMark   }
1574a587d139SMark   *B           = d_mat;       // return it, set it in Mat, and set it up
1575a587d139SMark   A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues
1576*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1577a587d139SMark }
1578076ba34aSJunchao Zhang 
1579d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1580d71ae5a4SJacob Faibussowitsch {
1581076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1582076ba34aSJunchao Zhang 
1583076ba34aSJunchao Zhang   PetscFunctionBegin;
1584076ba34aSJunchao Zhang   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1585076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1586076ba34aSJunchao Zhang   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1587076ba34aSJunchao Zhang   else *mask = "PETSC_OFFLOAD_BOTH";
1588*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1589076ba34aSJunchao Zhang }
1590076ba34aSJunchao Zhang 
1591d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1592d71ae5a4SJacob Faibussowitsch {
1593076ba34aSJunchao Zhang   PetscMPIInt size;
1594076ba34aSJunchao Zhang   Mat         Ad, Ao;
1595076ba34aSJunchao Zhang   const char *amask, *bmask;
1596076ba34aSJunchao Zhang 
1597076ba34aSJunchao Zhang   PetscFunctionBegin;
15989566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));
1599076ba34aSJunchao Zhang 
1600076ba34aSJunchao Zhang   if (size == 1) {
16019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
16029566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1603076ba34aSJunchao Zhang   } else {
1604076ba34aSJunchao Zhang     Ad = ((Mat_MPIAIJ *)A->data)->A;
1605076ba34aSJunchao Zhang     Ao = ((Mat_MPIAIJ *)A->data)->B;
16069566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
16079566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
16089566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1609076ba34aSJunchao Zhang   }
1610*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1611076ba34aSJunchao Zhang }
1612