111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp> 2f0e6e2d1SJunchao Zhang #include <petscpkg_version.h> 3076ba34aSJunchao Zhang #include <petscsf.h> 442550becSJunchao Zhang #include <petsc/private/sfimpl.h> 58c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h> 642550becSJunchao Zhang #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp> 7076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp> 80e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp> 911d22bbfSJunchao Zhang 10d71ae5a4SJacob Faibussowitsch PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 11d71ae5a4SJacob Faibussowitsch { 125519a089SJose E. Roman Mat_SeqAIJKokkos *aijkok; 1330203840SJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 148c3ff71bSJunchao Zhang 158c3ff71bSJunchao Zhang PetscFunctionBegin; 169566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 1730203840SJunchao Zhang /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 1830203840SJunchao Zhang Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 1930203840SJunchao Zhang */ 2030203840SJunchao Zhang if (mode == MAT_FINAL_ASSEMBLY) { 2130203840SJunchao Zhang PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 2230203840SJunchao Zhang PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 2330203840SJunchao Zhang PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 2430203840SJunchao Zhang } 255519a089SJose E. Roman aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */ 26a587d139SMark if (aijkok && aijkok->device_mat_d.data()) { 27a587d139SMark A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this 28a587d139SMark } 29a587d139SMark 303ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 318c3ff71bSJunchao Zhang } 328c3ff71bSJunchao Zhang 33d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 34d71ae5a4SJacob Faibussowitsch { 358c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 368c3ff71bSJunchao Zhang 378c3ff71bSJunchao Zhang PetscFunctionBegin; 389566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(mat->rmap)); 399566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(mat->cmap)); 406a29ce69SStefano Zampini #if defined(PETSC_USE_DEBUG) 418c3ff71bSJunchao Zhang if (d_nnz) { 426a29ce69SStefano Zampini PetscInt i; 43ad540459SPierre Jolivet for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); 448c3ff71bSJunchao Zhang } 458c3ff71bSJunchao Zhang if (o_nnz) { 466a29ce69SStefano Zampini PetscInt i; 47ad540459SPierre Jolivet for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]); 488c3ff71bSJunchao Zhang } 496a29ce69SStefano Zampini #endif 506a29ce69SStefano Zampini #if defined(PETSC_USE_CTABLE) 51eec179cfSJacob Faibussowitsch PetscCall(PetscHMapIDestroy(&mpiaij->colmap)); 526a29ce69SStefano Zampini #else 539566063dSJacob Faibussowitsch PetscCall(PetscFree(mpiaij->colmap)); 546a29ce69SStefano Zampini #endif 559566063dSJacob Faibussowitsch PetscCall(PetscFree(mpiaij->garray)); 569566063dSJacob Faibussowitsch PetscCall(VecDestroy(&mpiaij->lvec)); 579566063dSJacob Faibussowitsch PetscCall(VecScatterDestroy(&mpiaij->Mvctx)); 586a29ce69SStefano Zampini /* Because the B will have been resized we simply destroy it and create a new one each time */ 599566063dSJacob Faibussowitsch PetscCall(MatDestroy(&mpiaij->B)); 606a29ce69SStefano Zampini 616a29ce69SStefano Zampini if (!mpiaij->A) { 629566063dSJacob Faibussowitsch PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A)); 639566063dSJacob Faibussowitsch PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n)); 646a29ce69SStefano Zampini } 656a29ce69SStefano Zampini if (!mpiaij->B) { 666a29ce69SStefano Zampini PetscMPIInt size; 679566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size)); 689566063dSJacob Faibussowitsch PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B)); 699566063dSJacob Faibussowitsch PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0)); 708c3ff71bSJunchao Zhang } 719566063dSJacob Faibussowitsch PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 729566063dSJacob Faibussowitsch PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 739566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz)); 749566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz)); 758c3ff71bSJunchao Zhang mat->preallocated = PETSC_TRUE; 763ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 778c3ff71bSJunchao Zhang } 788c3ff71bSJunchao Zhang 79d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 80d71ae5a4SJacob Faibussowitsch { 818c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 828c3ff71bSJunchao Zhang PetscInt nt; 838c3ff71bSJunchao Zhang 848c3ff71bSJunchao Zhang PetscFunctionBegin; 859566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 8608401ef6SPierre Jolivet PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 879566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 889566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 899566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 909566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 913ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 928c3ff71bSJunchao Zhang } 938c3ff71bSJunchao Zhang 94d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 95d71ae5a4SJacob Faibussowitsch { 968c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 978c3ff71bSJunchao Zhang PetscInt nt; 988c3ff71bSJunchao Zhang 998c3ff71bSJunchao Zhang PetscFunctionBegin; 1009566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 10108401ef6SPierre Jolivet PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 1029566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 1039566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 1049566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 1059566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 1063ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1078c3ff71bSJunchao Zhang } 1088c3ff71bSJunchao Zhang 109d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 110d71ae5a4SJacob Faibussowitsch { 1118c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 1128c3ff71bSJunchao Zhang PetscInt nt; 1138c3ff71bSJunchao Zhang 1148c3ff71bSJunchao Zhang PetscFunctionBegin; 1159566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 11608401ef6SPierre Jolivet PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 1179566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 1189566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 1199566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 1209566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 1213ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1228c3ff71bSJunchao Zhang } 1238c3ff71bSJunchao Zhang 124076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 125076ba34aSJunchao Zhang A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 126076ba34aSJunchao Zhang C still uses local column ids. Their corresponding global column ids are returned in glob. 127076ba34aSJunchao Zhang */ 128d71ae5a4SJacob Faibussowitsch PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 129d71ae5a4SJacob Faibussowitsch { 130076ba34aSJunchao Zhang Mat Ad, Ao; 131076ba34aSJunchao Zhang const PetscInt *cmap; 132076ba34aSJunchao Zhang 133076ba34aSJunchao Zhang PetscFunctionBegin; 1349566063dSJacob Faibussowitsch PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 1359566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 136076ba34aSJunchao Zhang if (glob) { 137076ba34aSJunchao Zhang PetscInt cst, i, dn, on, *gidx; 1389566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 1399566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(Ao, NULL, &on)); 1409566063dSJacob Faibussowitsch PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 1419566063dSJacob Faibussowitsch PetscCall(PetscMalloc1(dn + on, &gidx)); 142076ba34aSJunchao Zhang for (i = 0; i < dn; i++) gidx[i] = cst + i; 143076ba34aSJunchao Zhang for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 1449566063dSJacob Faibussowitsch PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 145076ba34aSJunchao Zhang } 1463ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 147076ba34aSJunchao Zhang } 148076ba34aSJunchao Zhang 1490e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 150076ba34aSJunchao Zhang struct MatMatStruct { 1510e3ece09SJunchao Zhang PetscInt n, *garray; // C's garray and its size. 1520e3ece09SJunchao Zhang KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 1530e3ece09SJunchao Zhang KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 1540e3ece09SJunchao Zhang KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 1550e3ece09SJunchao Zhang PetscIntKokkosView E_NzLeft; 1560e3ece09SJunchao Zhang PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 1570e3ece09SJunchao Zhang MatScalarKokkosView rootBuf, leafBuf; 1580e3ece09SJunchao Zhang KokkosCsrMatrix Fd, Fo; // F in split form 1590e3ece09SJunchao Zhang 1600e3ece09SJunchao Zhang KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 1610e3ece09SJunchao Zhang KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 1620e3ece09SJunchao Zhang KernelHandle kh3; // compute C3 1630e3ece09SJunchao Zhang KernelHandle kh4; // compute C4 1640e3ece09SJunchao Zhang 1650e3ece09SJunchao Zhang PetscInt E_TeamSize; // kernel launching parameters in merging E or spliting F 1660e3ece09SJunchao Zhang PetscInt E_VectorLength; 1670e3ece09SJunchao Zhang PetscInt E_RowsPerTeam; 1680e3ece09SJunchao Zhang PetscInt F_TeamSize; 1690e3ece09SJunchao Zhang PetscInt F_VectorLength; 1700e3ece09SJunchao Zhang PetscInt F_RowsPerTeam; 171076ba34aSJunchao Zhang 172d71ae5a4SJacob Faibussowitsch ~MatMatStruct() 173d71ae5a4SJacob Faibussowitsch { 1743ba16761SJacob Faibussowitsch PetscFunctionBegin; 1753ba16761SJacob Faibussowitsch PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 1763ba16761SJacob Faibussowitsch PetscFunctionReturnVoid(); 177076ba34aSJunchao Zhang } 178076ba34aSJunchao Zhang }; 179076ba34aSJunchao Zhang 180076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct { 1810e3ece09SJunchao Zhang PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 1820e3ece09SJunchao Zhang PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 1830e3ece09SJunchao Zhang PetscIntKokkosView rowoffset; 184076ba34aSJunchao Zhang }; 185076ba34aSJunchao Zhang 186076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct { 1870e3ece09SJunchao Zhang MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 1880e3ece09SJunchao Zhang MatColIdxKokkosView Fdjperm; 1890e3ece09SJunchao Zhang MatColIdxKokkosView Fojmap; 1900e3ece09SJunchao Zhang MatColIdxKokkosView Fojperm; 191076ba34aSJunchao Zhang }; 192076ba34aSJunchao Zhang 1939371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos { 1943ba16761SJacob Faibussowitsch MatMatStruct_AB *mmAB = nullptr; 1953ba16761SJacob Faibussowitsch MatMatStruct_AtB *mmAtB = nullptr; 1963ba16761SJacob Faibussowitsch PetscBool reusesym = PETSC_FALSE; 1970e3ece09SJunchao Zhang Mat Z = nullptr; // store Z=AB in computing BtAB 198076ba34aSJunchao Zhang 199d71ae5a4SJacob Faibussowitsch ~MatProductData_MPIAIJKokkos() 200d71ae5a4SJacob Faibussowitsch { 201076ba34aSJunchao Zhang delete mmAB; 202076ba34aSJunchao Zhang delete mmAtB; 2030e3ece09SJunchao Zhang PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 204076ba34aSJunchao Zhang } 205076ba34aSJunchao Zhang }; 206076ba34aSJunchao Zhang 207d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 208d71ae5a4SJacob Faibussowitsch { 209076ba34aSJunchao Zhang PetscFunctionBegin; 2109566063dSJacob Faibussowitsch PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 2113ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 212076ba34aSJunchao Zhang } 213076ba34aSJunchao Zhang 214076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 215076ba34aSJunchao Zhang It is similar to MatCreateMPIAIJWithSplitArrays. 216076ba34aSJunchao Zhang 217076ba34aSJunchao Zhang Input Parameters: 218076ba34aSJunchao Zhang + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 219076ba34aSJunchao Zhang . A - the diag matrix using local col ids 220076ba34aSJunchao Zhang - B - the offdiag matrix using global col ids 221076ba34aSJunchao Zhang 222076ba34aSJunchao Zhang Output Parameters: 223076ba34aSJunchao Zhang . mat - the updated MATMPIAIJKOKKOS matrix 224076ba34aSJunchao Zhang */ 2250e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 226d71ae5a4SJacob Faibussowitsch { 227076ba34aSJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 228076ba34aSJunchao Zhang PetscInt m, n, M, N, Am, An, Bm, Bn; 229076ba34aSJunchao Zhang 230076ba34aSJunchao Zhang PetscFunctionBegin; 2319566063dSJacob Faibussowitsch PetscCall(MatGetSize(mat, &M, &N)); 2329566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(mat, &m, &n)); 2339566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(A, &Am, &An)); 2349566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 235076ba34aSJunchao Zhang 236aed4548fSBarry Smith PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 23708401ef6SPierre Jolivet PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 2380e3ece09SJunchao Zhang // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 23908401ef6SPierre Jolivet PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 240076ba34aSJunchao Zhang mpiaij->A = A; 241076ba34aSJunchao Zhang mpiaij->B = B; 2420e3ece09SJunchao Zhang mpiaij->garray = garray; 243076ba34aSJunchao Zhang 244076ba34aSJunchao Zhang mat->preallocated = PETSC_TRUE; 245076ba34aSJunchao Zhang mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 246076ba34aSJunchao Zhang 2479566063dSJacob Faibussowitsch PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 2489566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 249076ba34aSJunchao Zhang /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and 250076ba34aSJunchao Zhang also gets mpiaij->B compacted, with its col ids and size reduced 251076ba34aSJunchao Zhang */ 2529566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 2539566063dSJacob Faibussowitsch PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 2549566063dSJacob Faibussowitsch PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 2553ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 256076ba34aSJunchao Zhang } 257076ba34aSJunchao Zhang 2580e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 2590e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 2600e3ece09SJunchao Zhang template <class ExecutionSpace> 2610e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 262d71ae5a4SJacob Faibussowitsch { 2630e3ece09SJunchao Zhang Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 264076ba34aSJunchao Zhang 265076ba34aSJunchao Zhang PetscFunctionBegin; 2660e3ece09SJunchao Zhang PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 267076ba34aSJunchao Zhang 2680e3ece09SJunchao Zhang if (nnz_per_row < 1) nnz_per_row = 1; 269076ba34aSJunchao Zhang 2700e3ece09SJunchao Zhang int max_vector_length = teamPolicy.vector_length_max(); 271076ba34aSJunchao Zhang 2720e3ece09SJunchao Zhang if (vector_length < 1) { 2730e3ece09SJunchao Zhang vector_length = 1; 2740e3ece09SJunchao Zhang while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 275076ba34aSJunchao Zhang } 276076ba34aSJunchao Zhang 2770e3ece09SJunchao Zhang // Determine rows per thread 2780e3ece09SJunchao Zhang if (rows_per_thread < 1) { 2790e3ece09SJunchao Zhang if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 2800e3ece09SJunchao Zhang else { 2810e3ece09SJunchao Zhang if (nnz_per_row < 20 && nnz > 5000000) { 2820e3ece09SJunchao Zhang rows_per_thread = 256; 2830e3ece09SJunchao Zhang } else rows_per_thread = 64; 284076ba34aSJunchao Zhang } 285076ba34aSJunchao Zhang } 286076ba34aSJunchao Zhang 2870e3ece09SJunchao Zhang if (team_size < 1) { 2880e3ece09SJunchao Zhang if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 2890e3ece09SJunchao Zhang team_size = 256 / vector_length; 290076ba34aSJunchao Zhang } else { 2910e3ece09SJunchao Zhang team_size = 1; 2920e3ece09SJunchao Zhang } 293076ba34aSJunchao Zhang } 294076ba34aSJunchao Zhang 2950e3ece09SJunchao Zhang rows_per_team = rows_per_thread * team_size; 296076ba34aSJunchao Zhang 2970e3ece09SJunchao Zhang if (rows_per_team < 0) { 2980e3ece09SJunchao Zhang PetscInt nnz_per_team = 4096; 2990e3ece09SJunchao Zhang PetscInt conc = ExecutionSpace().concurrency(); 3000e3ece09SJunchao Zhang while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 3010e3ece09SJunchao Zhang rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 3020e3ece09SJunchao Zhang } 3033ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 304076ba34aSJunchao Zhang } 305076ba34aSJunchao Zhang 3060e3ece09SJunchao Zhang /* 3070e3ece09SJunchao Zhang Reduce two sets of global indices into local ones 308076ba34aSJunchao Zhang 309076ba34aSJunchao Zhang Input Parameters: 3100e3ece09SJunchao Zhang + n1 - size of garray1[], the first set 3110e3ece09SJunchao Zhang . garray1[n1] - a sorted global index array (without duplicates) 3120e3ece09SJunchao Zhang . m - size of indices[], the second set 3130e3ece09SJunchao Zhang - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 314076ba34aSJunchao Zhang 315076ba34aSJunchao Zhang Output Parameters: 3160e3ece09SJunchao Zhang + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 3170e3ece09SJunchao Zhang . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 3180e3ece09SJunchao Zhang . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 3190e3ece09SJunchao Zhang - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 320076ba34aSJunchao Zhang 3210e3ece09SJunchao Zhang Example, say 3220e3ece09SJunchao Zhang n1 = 5 3230e3ece09SJunchao Zhang garray1[5] = {1, 4, 7, 8, 10} 3240e3ece09SJunchao Zhang m = 4 3250e3ece09SJunchao Zhang indices[4] = {2, 4, 8, 9} 32611a5261eSBarry Smith 3270e3ece09SJunchao Zhang Combining them together, we have 7 global indices in garray2[] 3280e3ece09SJunchao Zhang n2 = 7 3290e3ece09SJunchao Zhang garray2[7] = {1, 2, 4, 7, 8, 9, 10} 3300e3ece09SJunchao Zhang 3310e3ece09SJunchao Zhang And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 3320e3ece09SJunchao Zhang map[5] = {0, 2, 3, 4, 6} 3330e3ece09SJunchao Zhang 3340e3ece09SJunchao Zhang On output, indices[] is updated with local indices 3350e3ece09SJunchao Zhang indices[4] = {1, 2, 4, 5} 336076ba34aSJunchao Zhang */ 3370e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 338d71ae5a4SJacob Faibussowitsch { 3390e3ece09SJunchao Zhang PetscHMapI g2l = nullptr; 3400e3ece09SJunchao Zhang PetscHashIter iter; 3410e3ece09SJunchao Zhang PetscInt tot, key, val; // total unique global indices. key is global id; val is local id 3420e3ece09SJunchao Zhang PetscInt n2, *garray2; 343076ba34aSJunchao Zhang 344076ba34aSJunchao Zhang PetscFunctionBegin; 3450e3ece09SJunchao Zhang tot = 0; 3460e3ece09SJunchao Zhang PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 3470e3ece09SJunchao Zhang for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 3480e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 3490e3ece09SJunchao Zhang if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 350076ba34aSJunchao Zhang } 351076ba34aSJunchao Zhang 3520e3ece09SJunchao Zhang for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 3530e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 3540e3ece09SJunchao Zhang if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 355076ba34aSJunchao Zhang } 356076ba34aSJunchao Zhang 3570e3ece09SJunchao Zhang // Pull out (unique) globals in the hash table and put them in garray2[] 3580e3ece09SJunchao Zhang n2 = tot; 3590e3ece09SJunchao Zhang PetscCall(PetscMalloc1(n2, &garray2)); 3600e3ece09SJunchao Zhang tot = 0; 3610e3ece09SJunchao Zhang PetscHashIterBegin(g2l, iter); 3620e3ece09SJunchao Zhang while (!PetscHashIterAtEnd(g2l, iter)) { 3630e3ece09SJunchao Zhang PetscHashIterGetKey(g2l, iter, key); 3640e3ece09SJunchao Zhang PetscHashIterNext(g2l, iter); 3650e3ece09SJunchao Zhang garray2[tot++] = key; 366076ba34aSJunchao Zhang } 367076ba34aSJunchao Zhang 3680e3ece09SJunchao Zhang // Sort garray2[] and then map them to local indices starting from 0 3690e3ece09SJunchao Zhang PetscCall(PetscSortInt(n2, garray2)); 3700e3ece09SJunchao Zhang PetscCall(PetscHMapIClear(g2l)); 3710e3ece09SJunchao Zhang for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 372f0e6e2d1SJunchao Zhang 3730e3ece09SJunchao Zhang // Rewrite indices[] with local indices 374f0e6e2d1SJunchao Zhang for (PetscInt i = 0; i < m; i++) { 3750e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 3760e3ece09SJunchao Zhang PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 3770e3ece09SJunchao Zhang indices[i] = val; 3780e3ece09SJunchao Zhang } 3790e3ece09SJunchao Zhang // Record the map that maps garray1[i] to garray2[map[i]] 3800e3ece09SJunchao Zhang for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 3810e3ece09SJunchao Zhang PetscCall(PetscHMapIDestroy(&g2l)); 3820e3ece09SJunchao Zhang *n2_ = n2; 3830e3ece09SJunchao Zhang *garray2_ = garray2; 3840e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3850e3ece09SJunchao Zhang } 386f0e6e2d1SJunchao Zhang 3870e3ece09SJunchao Zhang /* 3880e3ece09SJunchao Zhang MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 3890e3ece09SJunchao Zhang 3900e3ece09SJunchao Zhang It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 3910e3ece09SJunchao Zhang 3920e3ece09SJunchao Zhang Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 3930e3ece09SJunchao Zhang In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 3940e3ece09SJunchao Zhang 3950e3ece09SJunchao Zhang Input Parameters: 3960e3ece09SJunchao Zhang + comm - MPI communicator of E 3970e3ece09SJunchao Zhang . A - diag block of E, using local column indices 3980e3ece09SJunchao Zhang . B - off-diag block of E, using local column indices 3990e3ece09SJunchao Zhang . cstart - (global) start column of Ed 4000e3ece09SJunchao Zhang . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend) 4010e3ece09SJunchao Zhang . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size. 4020e3ece09SJunchao Zhang . ownerSF - the SF specifies ownership (root) of rows in E 4030e3ece09SJunchao Zhang . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 4040e3ece09SJunchao Zhang - mm - to stash intermediate data structures for reuse 4050e3ece09SJunchao Zhang 4060e3ece09SJunchao Zhang Output Parameters: 4070e3ece09SJunchao Zhang + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices(). 4080e3ece09SJunchao Zhang - mm - contains various info, such as garray2[], F (Fd, Fo) etc. 4090e3ece09SJunchao Zhang 4100e3ece09SJunchao Zhang Notes: 4110e3ece09SJunchao Zhang When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant. 4120e3ece09SJunchao Zhang 4130e3ece09SJunchao Zhang */ 4140e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 4150e3ece09SJunchao Zhang { 4160e3ece09SJunchao Zhang PetscFunctionBegin; 4170e3ece09SJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 4180e3ece09SJunchao Zhang PetscInt Em = A.numRows(), Fm; 4190e3ece09SJunchao Zhang PetscInt n1 = B.numCols(); 4200e3ece09SJunchao Zhang 4210e3ece09SJunchao Zhang PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 4220e3ece09SJunchao Zhang 4230e3ece09SJunchao Zhang // Do the analysis on host 4240e3ece09SJunchao Zhang auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 4250e3ece09SJunchao Zhang auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 4260e3ece09SJunchao Zhang auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 4270e3ece09SJunchao Zhang auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 4280e3ece09SJunchao Zhang const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 4290e3ece09SJunchao Zhang const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 4300e3ece09SJunchao Zhang 4310e3ece09SJunchao Zhang // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 4320e3ece09SJunchao Zhang PetscIntKokkosViewHost E_NzLeft_h("E_NzLeft_h", Em), E_RowLen_h("E_RowLen_h", Em); 4330e3ece09SJunchao Zhang PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 4340e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) { 4350e3ece09SJunchao Zhang const PetscInt *first, *last, *it; 4360e3ece09SJunchao Zhang PetscInt count, step; 4370e3ece09SJunchao Zhang // std::lower_bound(first,last,cstart), but need to use global column indices 4380e3ece09SJunchao Zhang first = Bj + Bi[i]; 4390e3ece09SJunchao Zhang last = Bj + Bi[i + 1]; 440f0e6e2d1SJunchao Zhang count = last - first; 441f0e6e2d1SJunchao Zhang while (count > 0) { 442f0e6e2d1SJunchao Zhang it = first; 443f0e6e2d1SJunchao Zhang step = count / 2; 444f0e6e2d1SJunchao Zhang it += step; 4450e3ece09SJunchao Zhang if (garray1[*it] < cstart) { // map local to global 446f0e6e2d1SJunchao Zhang first = ++it; 447f0e6e2d1SJunchao Zhang count -= step + 1; 448f0e6e2d1SJunchao Zhang } else count = step; 449f0e6e2d1SJunchao Zhang } 4500e3ece09SJunchao Zhang E_NzLeft[i] = first - (Bj + Bi[i]); 4510e3ece09SJunchao Zhang E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 452f0e6e2d1SJunchao Zhang } 453f0e6e2d1SJunchao Zhang 4540e3ece09SJunchao Zhang // Get length of rows (i.e., sizes of leaves) that contribute to my roots 4550e3ece09SJunchao Zhang const PetscMPIInt *iranks, *ranks; 4560e3ece09SJunchao Zhang const PetscInt *ioffset, *irootloc, *roffset, *rmine; 4570e3ece09SJunchao Zhang PetscInt niranks, nranks; 4580e3ece09SJunchao Zhang MPI_Request *reqs; 4590e3ece09SJunchao Zhang PetscMPIInt tag; 4600e3ece09SJunchao Zhang PetscSF reduceSF; 4610e3ece09SJunchao Zhang PetscInt *sdisp, *rdisp; 462f0e6e2d1SJunchao Zhang 4630e3ece09SJunchao Zhang PetscCall(PetscCommGetNewTag(comm, &tag)); 4640e3ece09SJunchao Zhang PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 4650e3ece09SJunchao Zhang PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 466f0e6e2d1SJunchao Zhang 4670e3ece09SJunchao Zhang // Find out length of each row I will receive. Even for the same row index, when they are from 4680e3ece09SJunchao Zhang // different senders, they might have different lengths (and sparsity patterns) 4690e3ece09SJunchao Zhang PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 4700e3ece09SJunchao Zhang PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 471f0e6e2d1SJunchao Zhang 4720e3ece09SJunchao Zhang PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 4730e3ece09SJunchao Zhang 4740e3ece09SJunchao Zhang for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 4750e3ece09SJunchao Zhang recvRowLen[0] = 0; // since we will make it in CSR format later 4760e3ece09SJunchao Zhang recvRowLen++; // advance the pointer now 4770e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 4780e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 4790e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 4800e3ece09SJunchao Zhang 4810e3ece09SJunchao Zhang // Build the real PetscSF for reducing E rows (buffer to buffer) 4820e3ece09SJunchao Zhang rdisp[0] = 0; 4830e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { 4840e3ece09SJunchao Zhang rdisp[i + 1] = rdisp[i]; 4850e3ece09SJunchao Zhang for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 4860e3ece09SJunchao Zhang } 4870e3ece09SJunchao Zhang recvRowLen--; // put it back into csr format 4880e3ece09SJunchao Zhang for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i]; 4890e3ece09SJunchao Zhang 4900e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 4910e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 4920e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 4930e3ece09SJunchao Zhang 4940e3ece09SJunchao Zhang PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 4950e3ece09SJunchao Zhang PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 4960e3ece09SJunchao Zhang PetscSFNode *iremote; 4970e3ece09SJunchao Zhang 4980e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 4990e3ece09SJunchao Zhang PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 5000e3ece09SJunchao Zhang PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 5010e3ece09SJunchao Zhang 5020e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { 5030e3ece09SJunchao Zhang PetscInt count = 0; 5040e3ece09SJunchao Zhang for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 5050e3ece09SJunchao Zhang for (PetscInt j = 0; j < count; j++) { 5060e3ece09SJunchao Zhang iremote[nleaves + j].rank = ranks[i]; 5070e3ece09SJunchao Zhang iremote[nleaves + j].index = sdisp[i] + j; 5080e3ece09SJunchao Zhang } 5090e3ece09SJunchao Zhang nleaves += count; 5100e3ece09SJunchao Zhang } 5110e3ece09SJunchao Zhang PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 5120e3ece09SJunchao Zhang 5130e3ece09SJunchao Zhang PetscCall(PetscSFCreate(comm, &reduceSF)); 5140e3ece09SJunchao Zhang PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 5150e3ece09SJunchao Zhang 5160e3ece09SJunchao Zhang // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 5170e3ece09SJunchao Zhang PetscInt *sendCol, *recvCol; 5180e3ece09SJunchao Zhang PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 5190e3ece09SJunchao Zhang for (PetscInt k = 0; k < roffset[nranks]; k++) { 5200e3ece09SJunchao Zhang PetscInt i = rmine[k]; // row to be copied 5210e3ece09SJunchao Zhang PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 5220e3ece09SJunchao Zhang PetscInt nzLeft = E_NzLeft[i]; 5230e3ece09SJunchao Zhang PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 5240e3ece09SJunchao Zhang for (PetscInt j = 0; j < alen + blen; j++) { 5250e3ece09SJunchao Zhang if (j < nzLeft) { 5260e3ece09SJunchao Zhang buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 5270e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { 5280e3ece09SJunchao Zhang buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 5290e3ece09SJunchao Zhang } else { 5300e3ece09SJunchao Zhang buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 5310e3ece09SJunchao Zhang } 5320e3ece09SJunchao Zhang } 5330e3ece09SJunchao Zhang } 5340e3ece09SJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 5350e3ece09SJunchao Zhang PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 5360e3ece09SJunchao Zhang 5370e3ece09SJunchao Zhang // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 5380e3ece09SJunchao Zhang PetscInt *recvRowPerm, *recvColSorted; 5390e3ece09SJunchao Zhang PetscInt *recvNzPerm, *recvNzPermSorted; 5400e3ece09SJunchao Zhang PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 5410e3ece09SJunchao Zhang 5420e3ece09SJunchao Zhang for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 5430e3ece09SJunchao Zhang for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 5440e3ece09SJunchao Zhang PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 5450e3ece09SJunchao Zhang 5460e3ece09SJunchao Zhang // i[] array, nz are always easiest to compute 5470e3ece09SJunchao Zhang MatRowMapKokkosViewHost Fdi_h("Fdi_h", Fm + 1), Foi_h("Foi_h", Fm + 1); 5480e3ece09SJunchao Zhang MatRowMapType *Fdi, *Foi; 5490e3ece09SJunchao Zhang PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 5500e3ece09SJunchao Zhang PetscInt iter; 5510e3ece09SJunchao Zhang 5520e3ece09SJunchao Zhang Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 5530e3ece09SJunchao Zhang Kokkos::deep_copy(Foi_h, 0); 5540e3ece09SJunchao Zhang Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 5550e3ece09SJunchao Zhang Foi = Foi_h.data() + 1; 5560e3ece09SJunchao Zhang iter = 0; 5570e3ece09SJunchao Zhang while (iter < recvRowCnt) { // iter over received rows 5580e3ece09SJunchao Zhang PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 5590e3ece09SJunchao Zhang PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 5600e3ece09SJunchao Zhang 5610e3ece09SJunchao Zhang while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 5620e3ece09SJunchao Zhang 5630e3ece09SJunchao Zhang // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted 5640e3ece09SJunchao Zhang PetscInt nz = 0; // nz (with dups) in the current row 5650e3ece09SJunchao Zhang PetscInt *jbuf = recvColSorted + FnzDups; 5660e3ece09SJunchao Zhang PetscInt *pbuf = recvNzPermSorted + FnzDups; 5670e3ece09SJunchao Zhang PetscInt *jbuf2 = jbuf; // temp pointers 5680e3ece09SJunchao Zhang PetscInt *pbuf2 = pbuf; 5690e3ece09SJunchao Zhang for (PetscInt d = 0; d < dupRows; d++) { 5700e3ece09SJunchao Zhang PetscInt i = recvRowPerm[iter + d]; 5710e3ece09SJunchao Zhang PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 5720e3ece09SJunchao Zhang PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 5730e3ece09SJunchao Zhang PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 5740e3ece09SJunchao Zhang jbuf2 += len; 5750e3ece09SJunchao Zhang pbuf2 += len; 5760e3ece09SJunchao Zhang nz += len; 5770e3ece09SJunchao Zhang } 5780e3ece09SJunchao Zhang PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 5790e3ece09SJunchao Zhang 5800e3ece09SJunchao Zhang // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 5810e3ece09SJunchao Zhang PetscInt cur = 0; 5820e3ece09SJunchao Zhang while (cur < nz) { 5830e3ece09SJunchao Zhang PetscInt curColIdx = jbuf[cur]; 5840e3ece09SJunchao Zhang PetscInt dups = 1; 5850e3ece09SJunchao Zhang 5860e3ece09SJunchao Zhang while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 5870e3ece09SJunchao Zhang if (curColIdx >= cstart && curColIdx < cend) { 5880e3ece09SJunchao Zhang Fdi[curRowIdx]++; 5890e3ece09SJunchao Zhang FdnzDups += dups; 5900e3ece09SJunchao Zhang } else { 5910e3ece09SJunchao Zhang Foi[curRowIdx]++; 5920e3ece09SJunchao Zhang FonzDups += dups; 5930e3ece09SJunchao Zhang } 5940e3ece09SJunchao Zhang cur += dups; 5950e3ece09SJunchao Zhang } 5960e3ece09SJunchao Zhang 5970e3ece09SJunchao Zhang FnzDups += nz; 5980e3ece09SJunchao Zhang iter += dupRows; // Move to next unique row 5990e3ece09SJunchao Zhang } 6000e3ece09SJunchao Zhang 6010e3ece09SJunchao Zhang Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 6020e3ece09SJunchao Zhang Foi = Foi_h.data(); 6030e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 6040e3ece09SJunchao Zhang Fdi[i + 1] += Fdi[i]; 6050e3ece09SJunchao Zhang Foi[i + 1] += Foi[i]; 6060e3ece09SJunchao Zhang } 6070e3ece09SJunchao Zhang Fdnz = Fdi[Fm]; 6080e3ece09SJunchao Zhang Fonz = Foi[Fm]; 6090e3ece09SJunchao Zhang PetscCall(PetscFree2(sendCol, recvCol)); 6100e3ece09SJunchao Zhang 6110e3ece09SJunchao Zhang // Allocate j, jmap, jperm for Fd and Fo 6120e3ece09SJunchao Zhang MatColIdxKokkosViewHost Fdj_h("Fdj_h", Fdnz), Foj_h("Foj_h", Fonz); 6130e3ece09SJunchao Zhang MatRowMapKokkosViewHost Fdjmap_h("Fdjmap_h", Fdnz + 1), Fojmap_h("Fojmap_h", Fonz + 1); // +1 to make csr 6140e3ece09SJunchao Zhang MatRowMapKokkosViewHost Fdjperm_h("Fdjperm_h", FdnzDups), Fojperm_h("Fojperm_h", FonzDups); 6150e3ece09SJunchao Zhang MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 6160e3ece09SJunchao Zhang MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 6170e3ece09SJunchao Zhang MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 6180e3ece09SJunchao Zhang 6190e3ece09SJunchao Zhang // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 6200e3ece09SJunchao Zhang Fdjmap[0] = 0; 6210e3ece09SJunchao Zhang Fojmap[0] = 0; 6220e3ece09SJunchao Zhang FnzDups = 0; 6230e3ece09SJunchao Zhang Fdnz = 0; 6240e3ece09SJunchao Zhang Fonz = 0; 6250e3ece09SJunchao Zhang iter = 0; // iter over received rows 6260e3ece09SJunchao Zhang while (iter < recvRowCnt) { 6270e3ece09SJunchao Zhang PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 6280e3ece09SJunchao Zhang PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 6290e3ece09SJunchao Zhang PetscInt nz = 0; // nz (with dups) in the current row 6300e3ece09SJunchao Zhang 6310e3ece09SJunchao Zhang while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 6320e3ece09SJunchao Zhang for (PetscInt d = 0; d < dupRows; d++) { 6330e3ece09SJunchao Zhang PetscInt i = recvRowPerm[iter + d]; 6340e3ece09SJunchao Zhang nz += recvRowLen[i + 1] - recvRowLen[i]; 6350e3ece09SJunchao Zhang } 6360e3ece09SJunchao Zhang 6370e3ece09SJunchao Zhang PetscInt *jbuf = recvColSorted + FnzDups; 6380e3ece09SJunchao Zhang // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 6390e3ece09SJunchao Zhang PetscInt cur = 0; 6400e3ece09SJunchao Zhang while (cur < nz) { 6410e3ece09SJunchao Zhang PetscInt curColIdx = jbuf[cur]; 6420e3ece09SJunchao Zhang PetscInt dups = 1; 6430e3ece09SJunchao Zhang 6440e3ece09SJunchao Zhang while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 6450e3ece09SJunchao Zhang if (curColIdx >= cstart && curColIdx < cend) { 6460e3ece09SJunchao Zhang Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 6470e3ece09SJunchao Zhang Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 6480e3ece09SJunchao Zhang for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 6490e3ece09SJunchao Zhang FdnzDups += dups; 6500e3ece09SJunchao Zhang Fdnz++; 6510e3ece09SJunchao Zhang } else { 6520e3ece09SJunchao Zhang Foj[Fonz] = curColIdx; // in global 6530e3ece09SJunchao Zhang Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 6540e3ece09SJunchao Zhang for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 6550e3ece09SJunchao Zhang FonzDups += dups; 6560e3ece09SJunchao Zhang Fonz++; 6570e3ece09SJunchao Zhang } 6580e3ece09SJunchao Zhang cur += dups; 6590e3ece09SJunchao Zhang FnzDups += dups; 6600e3ece09SJunchao Zhang } 6610e3ece09SJunchao Zhang iter += dupRows; // Move to next unique row 6620e3ece09SJunchao Zhang } 6630e3ece09SJunchao Zhang PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 6640e3ece09SJunchao Zhang PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs)); 6650e3ece09SJunchao Zhang 6660e3ece09SJunchao Zhang // Combine global column indices in garray1[] and Foj[] 6670e3ece09SJunchao Zhang PetscInt n2, *garray2; 6680e3ece09SJunchao Zhang 6690e3ece09SJunchao Zhang PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 6700e3ece09SJunchao Zhang mm->sf = reduceSF; 6710e3ece09SJunchao Zhang mm->leafBuf = MatScalarKokkosView("leafBuf", nleaves); 6720e3ece09SJunchao Zhang mm->rootBuf = MatScalarKokkosView("rootBuf", nroots); 6730e3ece09SJunchao Zhang mm->garray = garray2; // give owership, so no free 6740e3ece09SJunchao Zhang mm->n = n2; 6750e3ece09SJunchao Zhang mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 6760e3ece09SJunchao Zhang mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 6770e3ece09SJunchao Zhang mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 6780e3ece09SJunchao Zhang mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 6790e3ece09SJunchao Zhang mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 6800e3ece09SJunchao Zhang 6810e3ece09SJunchao Zhang // Output Fd and Fo in KokkosCsrMatrix format 6820e3ece09SJunchao Zhang MatScalarKokkosView Fda_d("Fda_d", Fdnz); 6830e3ece09SJunchao Zhang MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 6840e3ece09SJunchao Zhang MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 6850e3ece09SJunchao Zhang MatScalarKokkosView Foa_d("Foa_d", Fonz); 6860e3ece09SJunchao Zhang MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 6870e3ece09SJunchao Zhang MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 6880e3ece09SJunchao Zhang 6890e3ece09SJunchao Zhang PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 6900e3ece09SJunchao Zhang PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 6910e3ece09SJunchao Zhang 6920e3ece09SJunchao Zhang // Compute kernel launch parameters in merging E 6930e3ece09SJunchao Zhang PetscInt teamSize, vectorLength, rowsPerTeam; 6940e3ece09SJunchao Zhang 6950e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 6960e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 6970e3ece09SJunchao Zhang mm->E_TeamSize = teamSize; 6980e3ece09SJunchao Zhang mm->E_VectorLength = vectorLength; 6990e3ece09SJunchao Zhang mm->E_RowsPerTeam = rowsPerTeam; 7000e3ece09SJunchao Zhang } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 7010e3ece09SJunchao Zhang 7020e3ece09SJunchao Zhang // Handy aliases 7030e3ece09SJunchao Zhang auto &Aa = A.values; 7040e3ece09SJunchao Zhang auto &Ba = B.values; 7050e3ece09SJunchao Zhang const auto &Ai = A.graph.row_map; 7060e3ece09SJunchao Zhang const auto &Bi = B.graph.row_map; 7070e3ece09SJunchao Zhang const auto &E_NzLeft = mm->E_NzLeft; 7080e3ece09SJunchao Zhang auto &leafBuf = mm->leafBuf; 7090e3ece09SJunchao Zhang auto &rootBuf = mm->rootBuf; 7100e3ece09SJunchao Zhang PetscSF reduceSF = mm->sf; 7110e3ece09SJunchao Zhang PetscInt Em = A.numRows(); 7120e3ece09SJunchao Zhang PetscInt teamSize = mm->E_TeamSize; 7130e3ece09SJunchao Zhang PetscInt vectorLength = mm->E_VectorLength; 7140e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->E_RowsPerTeam; 7150e3ece09SJunchao Zhang PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 7160e3ece09SJunchao Zhang 7170e3ece09SJunchao Zhang // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 7180e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 7190e3ece09SJunchao Zhang Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 7200e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 7210e3ece09SJunchao Zhang PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 7220e3ece09SJunchao Zhang if (i < Em) { 7230e3ece09SJunchao Zhang PetscInt disp = Ai(i) + Bi(i); 7240e3ece09SJunchao Zhang PetscInt alen = Ai(i + 1) - Ai(i); 7250e3ece09SJunchao Zhang PetscInt blen = Bi(i + 1) - Bi(i); 7260e3ece09SJunchao Zhang PetscInt nzleft = E_NzLeft(i); 7270e3ece09SJunchao Zhang 7280e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 7290e3ece09SJunchao Zhang MatScalar &val = leafBuf(disp + j); 7300e3ece09SJunchao Zhang if (j < nzleft) { // B left 7310e3ece09SJunchao Zhang val = Ba(Bi(i) + j); 7320e3ece09SJunchao Zhang } else if (j < nzleft + alen) { // diag A 7330e3ece09SJunchao Zhang val = Aa(Ai(i) + j - nzleft); 7340e3ece09SJunchao Zhang } else { // B right 7350e3ece09SJunchao Zhang val = Ba(Bi(i) + j - alen); 736f0e6e2d1SJunchao Zhang } 737f0e6e2d1SJunchao Zhang }); 738f0e6e2d1SJunchao Zhang } 739f0e6e2d1SJunchao Zhang }); 7400e3ece09SJunchao Zhang })); 7410e3ece09SJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 742f0e6e2d1SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 743f0e6e2d1SJunchao Zhang } 7440e3ece09SJunchao Zhang 7450e3ece09SJunchao Zhang // To finsih MatMPIAIJKokkosReduce. 7460e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 7470e3ece09SJunchao Zhang { 7480e3ece09SJunchao Zhang PetscFunctionBegin; 7490e3ece09SJunchao Zhang auto &leafBuf = mm->leafBuf; 7500e3ece09SJunchao Zhang auto &rootBuf = mm->rootBuf; 7510e3ece09SJunchao Zhang auto &Fda = mm->Fd.values; 7520e3ece09SJunchao Zhang const auto &Fdjmap = mm->Fdjmap; 7530e3ece09SJunchao Zhang const auto &Fdjperm = mm->Fdjperm; 7540e3ece09SJunchao Zhang auto Fdnz = mm->Fd.nnz(); 7550e3ece09SJunchao Zhang auto &Foa = mm->Fo.values; 7560e3ece09SJunchao Zhang const auto &Fojmap = mm->Fojmap; 7570e3ece09SJunchao Zhang const auto &Fojperm = mm->Fojperm; 7580e3ece09SJunchao Zhang auto Fonz = mm->Fo.nnz(); 7590e3ece09SJunchao Zhang PetscSF reduceSF = mm->sf; 7600e3ece09SJunchao Zhang 7610e3ece09SJunchao Zhang PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 7620e3ece09SJunchao Zhang 7630e3ece09SJunchao Zhang // Reduce data in rootBuf to Fd and Fo 7640e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 7650e3ece09SJunchao Zhang Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) { 7660e3ece09SJunchao Zhang PetscScalar sum = 0.0; 7670e3ece09SJunchao Zhang for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 7680e3ece09SJunchao Zhang Fda(i) = sum; 7690e3ece09SJunchao Zhang })); 7700e3ece09SJunchao Zhang 7710e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 7720e3ece09SJunchao Zhang Fonz, KOKKOS_LAMBDA(const MatRowMapType i) { 7730e3ece09SJunchao Zhang PetscScalar sum = 0.0; 7740e3ece09SJunchao Zhang for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 7750e3ece09SJunchao Zhang Foa(i) = sum; 7760e3ece09SJunchao Zhang })); 7770e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 7780e3ece09SJunchao Zhang } 7790e3ece09SJunchao Zhang 7800e3ece09SJunchao Zhang /* 7810e3ece09SJunchao Zhang MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 7820e3ece09SJunchao Zhang 7830e3ece09SJunchao Zhang This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 7840e3ece09SJunchao Zhang device and involves various index mapping. 7850e3ece09SJunchao Zhang 7860e3ece09SJunchao Zhang In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 7870e3ece09SJunchao Zhang Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 7880e3ece09SJunchao Zhang to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 7890e3ece09SJunchao Zhang F has the same column layout as E. 7900e3ece09SJunchao Zhang 7910e3ece09SJunchao Zhang Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 7920e3ece09SJunchao Zhang Fd uses local column indices, which are easy to compute. We just need to substract the "local column range start" from the global indices. 7930e3ece09SJunchao Zhang Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 7940e3ece09SJunchao Zhang column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 7950e3ece09SJunchao Zhang column indices in Fo and update Fo with local indices. 7960e3ece09SJunchao Zhang 7970e3ece09SJunchao Zhang Input Parameters: 7980e3ece09SJunchao Zhang + E - the MPIAIJKOKKOS matrix 7990e3ece09SJunchao Zhang . ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX) 8000e3ece09SJunchao Zhang . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 8010e3ece09SJunchao Zhang - mm - to stash matproduct intermediate data structures 8020e3ece09SJunchao Zhang 8030e3ece09SJunchao Zhang Output Parameters: 8040e3ece09SJunchao Zhang + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 8050e3ece09SJunchao Zhang - mm - contains various info, such as garray2[], Fd, Fo, etc. 8060e3ece09SJunchao Zhang 8070e3ece09SJunchao Zhang Notes: 8080e3ece09SJunchao Zhang When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 8090e3ece09SJunchao Zhang The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 8100e3ece09SJunchao Zhang */ 8110e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 8120e3ece09SJunchao Zhang { 8130e3ece09SJunchao Zhang Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 8140e3ece09SJunchao Zhang Mat A = empi->A, B = empi->B; // diag and off-diag 8150e3ece09SJunchao Zhang Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 8160e3ece09SJunchao Zhang PetscInt Em = E->rmap->n; // #local rows 8170e3ece09SJunchao Zhang MPI_Comm comm; 8180e3ece09SJunchao Zhang 8190e3ece09SJunchao Zhang PetscFunctionBegin; 8200e3ece09SJunchao Zhang PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 8210e3ece09SJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 8220e3ece09SJunchao Zhang Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 8230e3ece09SJunchao Zhang PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 8240e3ece09SJunchao Zhang const PetscInt *garray1 = empi->garray; // its size is n1 8250e3ece09SJunchao Zhang PetscInt cstart, cend; 8260e3ece09SJunchao Zhang PetscSF bcastSF; 8270e3ece09SJunchao Zhang 8280e3ece09SJunchao Zhang PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 8290e3ece09SJunchao Zhang 8300e3ece09SJunchao Zhang // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 8310e3ece09SJunchao Zhang PetscIntKokkosViewHost E_NzLeft_h("E_NzLeft_h", Em), E_RowLen_h("E_RowLen_h", Em); 8320e3ece09SJunchao Zhang PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 8330e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) { 8340e3ece09SJunchao Zhang const PetscInt *first, *last, *it; 8350e3ece09SJunchao Zhang PetscInt count, step; 8360e3ece09SJunchao Zhang // std::lower_bound(first,last,cstart), but need to use global column indices 8370e3ece09SJunchao Zhang first = Bj + Bi[i]; 8380e3ece09SJunchao Zhang last = Bj + Bi[i + 1]; 8390e3ece09SJunchao Zhang count = last - first; 8400e3ece09SJunchao Zhang while (count > 0) { 8410e3ece09SJunchao Zhang it = first; 8420e3ece09SJunchao Zhang step = count / 2; 8430e3ece09SJunchao Zhang it += step; 8440e3ece09SJunchao Zhang if (empi->garray[*it] < cstart) { // map local to global 8450e3ece09SJunchao Zhang first = ++it; 8460e3ece09SJunchao Zhang count -= step + 1; 8470e3ece09SJunchao Zhang } else count = step; 8480e3ece09SJunchao Zhang } 8490e3ece09SJunchao Zhang E_NzLeft[i] = first - (Bj + Bi[i]); 8500e3ece09SJunchao Zhang E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 8510e3ece09SJunchao Zhang } 8520e3ece09SJunchao Zhang 8530e3ece09SJunchao Zhang // Compute row pointer Fi of F 8540e3ece09SJunchao Zhang PetscInt *Fi, Fm, Fnz; 8550e3ece09SJunchao Zhang PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 8560e3ece09SJunchao Zhang PetscCall(PetscMalloc1(Fm + 1, &Fi)); 8570e3ece09SJunchao Zhang Fi[0] = 0; 8580e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 8590e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 8600e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 8610e3ece09SJunchao Zhang Fnz = Fi[Fm]; 8620e3ece09SJunchao Zhang 8630e3ece09SJunchao Zhang // Build the real PetscSF for bcasting E rows (buffer to buffer) 8640e3ece09SJunchao Zhang const PetscMPIInt *iranks, *ranks; 8650e3ece09SJunchao Zhang const PetscInt *ioffset, *irootloc, *roffset; 8660e3ece09SJunchao Zhang PetscInt niranks, nranks, *sdisp, *rdisp; 8670e3ece09SJunchao Zhang MPI_Request *reqs; 8680e3ece09SJunchao Zhang PetscMPIInt tag; 8690e3ece09SJunchao Zhang 8700e3ece09SJunchao Zhang PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 8710e3ece09SJunchao Zhang PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 8720e3ece09SJunchao Zhang PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 8730e3ece09SJunchao Zhang 8740e3ece09SJunchao Zhang sdisp[0] = 0; // send displacement 8750e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { 8760e3ece09SJunchao Zhang sdisp[i + 1] = sdisp[i]; 8770e3ece09SJunchao Zhang for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 8780e3ece09SJunchao Zhang PetscInt r = irootloc[j]; // row to be sent 8790e3ece09SJunchao Zhang sdisp[i + 1] += E_RowLen[r]; 8800e3ece09SJunchao Zhang } 8810e3ece09SJunchao Zhang } 8820e3ece09SJunchao Zhang 8830e3ece09SJunchao Zhang PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 8840e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 8850e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 8860e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 8870e3ece09SJunchao Zhang 8880e3ece09SJunchao Zhang PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 8890e3ece09SJunchao Zhang PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 8900e3ece09SJunchao Zhang PetscSFNode *iremote; // give ownership to bcastSF 8910e3ece09SJunchao Zhang PetscCall(PetscMalloc1(nleaves, &iremote)); 8920e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 8930e3ece09SJunchao Zhang PetscInt k = 0; 8940e3ece09SJunchao Zhang for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i] 8950e3ece09SJunchao Zhang iremote[j].rank = ranks[i]; 8960e3ece09SJunchao Zhang iremote[j].index = rdisp[i] + k; // their root location 8970e3ece09SJunchao Zhang k++; 8980e3ece09SJunchao Zhang } 8990e3ece09SJunchao Zhang } 9000e3ece09SJunchao Zhang PetscCall(PetscSFCreate(comm, &bcastSF)); 9010e3ece09SJunchao Zhang PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 9020e3ece09SJunchao Zhang PetscCall(PetscFree3(sdisp, rdisp, reqs)); 9030e3ece09SJunchao Zhang 9040e3ece09SJunchao Zhang // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 9050e3ece09SJunchao Zhang PetscIntKokkosViewHost rowoffset_h("rowoffset_h", ioffset[niranks] + 1); 9060e3ece09SJunchao Zhang PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 9070e3ece09SJunchao Zhang rowoffset[0] = 0; 9080e3ece09SJunchao Zhang for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] += rowoffset[i] + E_RowLen[irootloc[i]]; } 9090e3ece09SJunchao Zhang 9100e3ece09SJunchao Zhang // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 9110e3ece09SJunchao Zhang PetscInt *jbuf, *Fj; 9120e3ece09SJunchao Zhang PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 9130e3ece09SJunchao Zhang for (PetscInt k = 0; k < ioffset[niranks]; k++) { 9140e3ece09SJunchao Zhang PetscInt i = irootloc[k]; // row to be copied 9150e3ece09SJunchao Zhang PetscInt *buf = &jbuf[rowoffset[k]]; 9160e3ece09SJunchao Zhang PetscInt nzLeft = E_NzLeft[i]; 9170e3ece09SJunchao Zhang PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 9180e3ece09SJunchao Zhang for (PetscInt j = 0; j < alen + blen; j++) { 9190e3ece09SJunchao Zhang if (j < nzLeft) { 9200e3ece09SJunchao Zhang buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 9210e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { 9220e3ece09SJunchao Zhang buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 9230e3ece09SJunchao Zhang } else { 9240e3ece09SJunchao Zhang buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 9250e3ece09SJunchao Zhang } 9260e3ece09SJunchao Zhang } 9270e3ece09SJunchao Zhang } 9280e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 9290e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 9300e3ece09SJunchao Zhang 9310e3ece09SJunchao Zhang // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 9320e3ece09SJunchao Zhang MatRowMapKokkosViewHost Fdi_h("Fdi_h", Fm + 1), Foi_h("Foi_h", Fm + 1); // row pointer of Fd, Fo 9330e3ece09SJunchao Zhang MatColIdxKokkosViewHost F_NzLeft_h("F_NzLeft_h", Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 9340e3ece09SJunchao Zhang MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 9350e3ece09SJunchao Zhang MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 9360e3ece09SJunchao Zhang 9370e3ece09SJunchao Zhang Fdi[0] = Foi[0] = 0; 9380e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 9390e3ece09SJunchao Zhang PetscInt *first, *last, *lb1, *lb2; 9400e3ece09SJunchao Zhang // cut the row into: Left, [cstart, cend), Right 9410e3ece09SJunchao Zhang first = Fj + Fi[i]; 9420e3ece09SJunchao Zhang last = Fj + Fi[i + 1]; 9430e3ece09SJunchao Zhang lb1 = std::lower_bound(first, last, cstart); 9440e3ece09SJunchao Zhang F_NzLeft[i] = lb1 - first; 9450e3ece09SJunchao Zhang lb2 = std::lower_bound(first, last, cend); 9460e3ece09SJunchao Zhang Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 9470e3ece09SJunchao Zhang Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 9480e3ece09SJunchao Zhang } 9490e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 9500e3ece09SJunchao Zhang Fdi[i + 1] += Fdi[i]; 9510e3ece09SJunchao Zhang Foi[i + 1] += Foi[i]; 9520e3ece09SJunchao Zhang } 9530e3ece09SJunchao Zhang 9540e3ece09SJunchao Zhang // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 9550e3ece09SJunchao Zhang PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 9560e3ece09SJunchao Zhang MatColIdxKokkosViewHost Fdj_h("Fdj_h", Fdnz), Foj_h("Foj_h", Fonz); 9570e3ece09SJunchao Zhang MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 9580e3ece09SJunchao Zhang 9590e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 9600e3ece09SJunchao Zhang PetscInt nzLeft = F_NzLeft[i]; 9610e3ece09SJunchao Zhang PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 9620e3ece09SJunchao Zhang for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 9630e3ece09SJunchao Zhang gid = Fj[Fi[i] + j]; 9640e3ece09SJunchao Zhang if (j < nzLeft) { // left, in global 9650e3ece09SJunchao Zhang Foj[Foi[i] + j] = gid; 9660e3ece09SJunchao Zhang } else if (j < nzLeft + len) { // diag, in local 9670e3ece09SJunchao Zhang Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 9680e3ece09SJunchao Zhang } else { // right, in global 9690e3ece09SJunchao Zhang Foj[Foi[i] + j - len] = gid; 9700e3ece09SJunchao Zhang } 9710e3ece09SJunchao Zhang } 9720e3ece09SJunchao Zhang } 9730e3ece09SJunchao Zhang PetscCall(PetscFree2(jbuf, Fj)); 9740e3ece09SJunchao Zhang PetscCall(PetscFree(Fi)); 9750e3ece09SJunchao Zhang 9760e3ece09SJunchao Zhang // Reduce global indices in Foj[] and garray1[] into local ones 9770e3ece09SJunchao Zhang PetscInt n2, *garray2; 9780e3ece09SJunchao Zhang PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 9790e3ece09SJunchao Zhang 9800e3ece09SJunchao Zhang // Record the plans built above, for reuse 9810e3ece09SJunchao Zhang PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety 9820e3ece09SJunchao Zhang PetscIntKokkosViewHost irootloc_h("irootloc_h", ioffset[niranks]); 9830e3ece09SJunchao Zhang Kokkos::deep_copy(irootloc_h, tmp); 9840e3ece09SJunchao Zhang mm->sf = bcastSF; 9850e3ece09SJunchao Zhang mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 9860e3ece09SJunchao Zhang mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 9870e3ece09SJunchao Zhang mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 9880e3ece09SJunchao Zhang mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 9890e3ece09SJunchao Zhang mm->rootBuf = MatScalarKokkosView("rootBuf", nroots); 9900e3ece09SJunchao Zhang mm->leafBuf = MatScalarKokkosView("leafBuf", nleaves); 9910e3ece09SJunchao Zhang mm->garray = garray2; 9920e3ece09SJunchao Zhang mm->n = n2; 9930e3ece09SJunchao Zhang 9940e3ece09SJunchao Zhang // Output Fd and Fo in KokkosCsrMatrix format 9950e3ece09SJunchao Zhang MatScalarKokkosView Fda_d("Fda_d", Fdnz), Foa_d("Foa_d", Fonz); 9960e3ece09SJunchao Zhang MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 9970e3ece09SJunchao Zhang MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 9980e3ece09SJunchao Zhang MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 9990e3ece09SJunchao Zhang MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 10000e3ece09SJunchao Zhang 10010e3ece09SJunchao Zhang PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 10020e3ece09SJunchao Zhang PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 10030e3ece09SJunchao Zhang 10040e3ece09SJunchao Zhang // Compute kernel launch parameters in merging E or splitting F 10050e3ece09SJunchao Zhang PetscInt teamSize, vectorLength, rowsPerTeam; 10060e3ece09SJunchao Zhang 10070e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 10080e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 10090e3ece09SJunchao Zhang mm->E_TeamSize = teamSize; 10100e3ece09SJunchao Zhang mm->E_VectorLength = vectorLength; 10110e3ece09SJunchao Zhang mm->E_RowsPerTeam = rowsPerTeam; 10120e3ece09SJunchao Zhang 10130e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 10140e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 10150e3ece09SJunchao Zhang mm->F_TeamSize = teamSize; 10160e3ece09SJunchao Zhang mm->F_VectorLength = vectorLength; 10170e3ece09SJunchao Zhang mm->F_RowsPerTeam = rowsPerTeam; 10180e3ece09SJunchao Zhang } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 10190e3ece09SJunchao Zhang 10200e3ece09SJunchao Zhang // Sync E's value to device 10210e3ece09SJunchao Zhang akok->a_dual.sync_device(); 10220e3ece09SJunchao Zhang bkok->a_dual.sync_device(); 10230e3ece09SJunchao Zhang 10240e3ece09SJunchao Zhang // Handy aliases 10250e3ece09SJunchao Zhang const auto &Aa = akok->a_dual.view_device(); 10260e3ece09SJunchao Zhang const auto &Ba = bkok->a_dual.view_device(); 10270e3ece09SJunchao Zhang const auto &Ai = akok->i_dual.view_device(); 10280e3ece09SJunchao Zhang const auto &Bi = bkok->i_dual.view_device(); 10290e3ece09SJunchao Zhang 10300e3ece09SJunchao Zhang // Fetch the plans 10310e3ece09SJunchao Zhang PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 10320e3ece09SJunchao Zhang PetscSF &bcastSF = mm->sf; 10330e3ece09SJunchao Zhang MatScalarKokkosView &rootBuf = mm->rootBuf; 10340e3ece09SJunchao Zhang MatScalarKokkosView &leafBuf = mm->leafBuf; 10350e3ece09SJunchao Zhang PetscIntKokkosView &irootloc = mm->irootloc; 10360e3ece09SJunchao Zhang PetscIntKokkosView &rowoffset = mm->rowoffset; 10370e3ece09SJunchao Zhang 10380e3ece09SJunchao Zhang PetscInt teamSize = mm->E_TeamSize; 10390e3ece09SJunchao Zhang PetscInt vectorLength = mm->E_VectorLength; 10400e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->E_RowsPerTeam; 10410e3ece09SJunchao Zhang PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 10420e3ece09SJunchao Zhang 10430e3ece09SJunchao Zhang // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 10440e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 10450e3ece09SJunchao Zhang Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 10460e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 10470e3ece09SJunchao Zhang size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 10480e3ece09SJunchao Zhang if (r < irootloc.extent(0)) { 10490e3ece09SJunchao Zhang PetscInt i = irootloc(r); // row i of E 10500e3ece09SJunchao Zhang PetscInt disp = rowoffset(r); 10510e3ece09SJunchao Zhang PetscInt alen = Ai(i + 1) - Ai(i); 10520e3ece09SJunchao Zhang PetscInt blen = Bi(i + 1) - Bi(i); 10530e3ece09SJunchao Zhang PetscInt nzleft = E_NzLeft(i); 10540e3ece09SJunchao Zhang 10550e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 10560e3ece09SJunchao Zhang if (j < nzleft) { // B left 10570e3ece09SJunchao Zhang rootBuf(disp + j) = Ba(Bi(i) + j); 10580e3ece09SJunchao Zhang } else if (j < nzleft + alen) { // diag A 10590e3ece09SJunchao Zhang rootBuf(disp + j) = Aa(Ai(i) + j - nzleft); 10600e3ece09SJunchao Zhang } else { // B right 10610e3ece09SJunchao Zhang rootBuf(disp + j) = Ba(Bi(i) + j - alen); 10620e3ece09SJunchao Zhang } 10630e3ece09SJunchao Zhang }); 10640e3ece09SJunchao Zhang } 10650e3ece09SJunchao Zhang }); 10660e3ece09SJunchao Zhang })); 10670e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 10680e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 10690e3ece09SJunchao Zhang } 10700e3ece09SJunchao Zhang 10710e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast. 10720e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 10730e3ece09SJunchao Zhang { 10740e3ece09SJunchao Zhang PetscFunctionBegin; 10750e3ece09SJunchao Zhang const auto &Fd = mm->Fd; 10760e3ece09SJunchao Zhang const auto &Fo = mm->Fo; 10770e3ece09SJunchao Zhang const auto &Fdi = Fd.graph.row_map; 10780e3ece09SJunchao Zhang const auto &Foi = Fo.graph.row_map; 10790e3ece09SJunchao Zhang auto &Fda = Fd.values; 10800e3ece09SJunchao Zhang auto &Foa = Fo.values; 10810e3ece09SJunchao Zhang auto Fm = Fd.numRows(); 10820e3ece09SJunchao Zhang 10830e3ece09SJunchao Zhang PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 10840e3ece09SJunchao Zhang PetscSF &bcastSF = mm->sf; 10850e3ece09SJunchao Zhang MatScalarKokkosView &rootBuf = mm->rootBuf; 10860e3ece09SJunchao Zhang MatScalarKokkosView &leafBuf = mm->leafBuf; 10870e3ece09SJunchao Zhang PetscInt teamSize = mm->F_TeamSize; 10880e3ece09SJunchao Zhang PetscInt vectorLength = mm->F_VectorLength; 10890e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->F_RowsPerTeam; 10900e3ece09SJunchao Zhang PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 10910e3ece09SJunchao Zhang 10920e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 10930e3ece09SJunchao Zhang 10940e3ece09SJunchao Zhang // Update Fda and Foa with new data in leafBuf (as if it is Fa) 10950e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 10960e3ece09SJunchao Zhang Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 10970e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 10980e3ece09SJunchao Zhang PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 10990e3ece09SJunchao Zhang if (i < Fm) { 11000e3ece09SJunchao Zhang PetscInt nzLeft = F_NzLeft(i); 11010e3ece09SJunchao Zhang PetscInt alen = Fdi(i + 1) - Fdi(i); 11020e3ece09SJunchao Zhang PetscInt blen = Foi(i + 1) - Foi(i); 11030e3ece09SJunchao Zhang PetscInt Fii = Fdi(i) + Foi(i); 11040e3ece09SJunchao Zhang 11050e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 11060e3ece09SJunchao Zhang PetscScalar val = leafBuf(Fii + j); 11070e3ece09SJunchao Zhang if (j < nzLeft) { // left 11080e3ece09SJunchao Zhang Foa(Foi(i) + j) = val; 11090e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { // diag 11100e3ece09SJunchao Zhang Fda(Fdi(i) + j - nzLeft) = val; 11110e3ece09SJunchao Zhang } else { // right 11120e3ece09SJunchao Zhang Foa(Foi(i) + j - alen) = val; 11130e3ece09SJunchao Zhang } 11140e3ece09SJunchao Zhang }); 11150e3ece09SJunchao Zhang } 11160e3ece09SJunchao Zhang }); 11170e3ece09SJunchao Zhang })); 11180e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 11190e3ece09SJunchao Zhang } 11200e3ece09SJunchao Zhang 11210e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 11220e3ece09SJunchao Zhang { 11230e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 11240e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 11250e3ece09SJunchao Zhang KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 11260e3ece09SJunchao Zhang PetscInt cstart, cend; 11270e3ece09SJunchao Zhang MPI_Comm comm; 11280e3ece09SJunchao Zhang 11290e3ece09SJunchao Zhang PetscFunctionBegin; 11300e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 11310e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 11320e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 11330e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 11340e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 11350e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 11360e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 11370e3ece09SJunchao Zhang 11380e3ece09SJunchao Zhang // TODO: add command line options to select spgemm algorithms 11390e3ece09SJunchao Zhang auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 11400e3ece09SJunchao Zhang 11410e3ece09SJunchao Zhang // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 11420e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 11430e3ece09SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 11440e3ece09SJunchao Zhang spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1145f0e6e2d1SJunchao Zhang #endif 11460e3ece09SJunchao Zhang #endif 11470e3ece09SJunchao Zhang 11480e3ece09SJunchao Zhang PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 11490e3ece09SJunchao Zhang PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 11500e3ece09SJunchao Zhang PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 11510e3ece09SJunchao Zhang PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 11520e3ece09SJunchao Zhang 11530e3ece09SJunchao Zhang // Aot * (B's diag + B's off-diag) 11540e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 11550e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 11560e3ece09SJunchao Zhang // KK spgemm_symbolic() only populates the result's row map, but not its columns. 11570e3ece09SJunchao Zhang // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 11580e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 11590e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 11600e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 11610e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C3)); 11620e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C4)); 11630e3ece09SJunchao Zhang #endif 11640e3ece09SJunchao Zhang 11650e3ece09SJunchao Zhang // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 11660e3ece09SJunchao Zhang PetscIntKokkosViewHost map_h("map_h", bmpi->B->cmap->n); 11670e3ece09SJunchao Zhang PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 11680e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 11690e3ece09SJunchao Zhang 11700e3ece09SJunchao Zhang // Adt * (B's diag + B's off-diag) 11710e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 11720e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11730e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 11740e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11750e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 11760e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C1)); 11770e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 11780e3ece09SJunchao Zhang #endif 11790e3ece09SJunchao Zhang 11800e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 11810e3ece09SJunchao Zhang 11820e3ece09SJunchao Zhang // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 11830e3ece09SJunchao Zhang MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj("j", oldj.extent(0)); 11840e3ece09SJunchao Zhang PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 11850e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 11860e3ece09SJunchao Zhang oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 11870e3ece09SJunchao Zhang PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 11880e3ece09SJunchao Zhang 11890e3ece09SJunchao Zhang // C = (C1+Fd, C2+Fo) 11900e3ece09SJunchao Zhang PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 11910e3ece09SJunchao Zhang PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 11920e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 11930e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 11940e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 11950e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 11960e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 11970e3ece09SJunchao Zhang } 11980e3ece09SJunchao Zhang 11990e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 12000e3ece09SJunchao Zhang { 12010e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 12020e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 12030e3ece09SJunchao Zhang KokkosCsrMatrix Adt, Aot, Bd, Bo; 12040e3ece09SJunchao Zhang MPI_Comm comm; 12050e3ece09SJunchao Zhang 12060e3ece09SJunchao Zhang PetscFunctionBegin; 12070e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 12080e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 12090e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 12100e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 12110e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 12120e3ece09SJunchao Zhang 12130e3ece09SJunchao Zhang // Aot * (B's diag + B's off-diag) 12140e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 12150e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 12160e3ece09SJunchao Zhang 12170e3ece09SJunchao Zhang // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 12180e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 12190e3ece09SJunchao Zhang 12200e3ece09SJunchao Zhang // Adt * (B's diag + B's off-diag) 12210e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 12220e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 12230e3ece09SJunchao Zhang 12240e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 12250e3ece09SJunchao Zhang 12260e3ece09SJunchao Zhang // C = (C1+Fd, C2+Fo) 12270e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 12280e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 12290e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 12300e3ece09SJunchao Zhang } 1231f0e6e2d1SJunchao Zhang 1232076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1233076ba34aSJunchao Zhang 1234076ba34aSJunchao Zhang Input Parameters: 1235076ba34aSJunchao Zhang + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1236076ba34aSJunchao Zhang . A - an MPIAIJKOKKOS matrix 1237076ba34aSJunchao Zhang . B - an MPIAIJKOKKOS matrix 1238076ba34aSJunchao Zhang - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 1239076ba34aSJunchao Zhang */ 1240d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1241d71ae5a4SJacob Faibussowitsch { 12420e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 12430e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 12440e3ece09SJunchao Zhang KokkosCsrMatrix Ad, Ao, Bd, Bo; 1245076ba34aSJunchao Zhang 1246076ba34aSJunchao Zhang PetscFunctionBegin; 12470e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 12480e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 12490e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 12500e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 12510e3ece09SJunchao Zhang 12520e3ece09SJunchao Zhang // TODO: add command line options to select spgemm algorithms 12530e3ece09SJunchao Zhang auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 12540e3ece09SJunchao Zhang 12550e3ece09SJunchao Zhang // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 12560e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 12570e3ece09SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 12580e3ece09SJunchao Zhang spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 12590e3ece09SJunchao Zhang #endif 1260f0e6e2d1SJunchao Zhang #endif 1261f0e6e2d1SJunchao Zhang 12620e3ece09SJunchao Zhang mm->kh1.create_spgemm_handle(spgemm_alg); 12630e3ece09SJunchao Zhang mm->kh2.create_spgemm_handle(spgemm_alg); 12640e3ece09SJunchao Zhang mm->kh3.create_spgemm_handle(spgemm_alg); 12650e3ece09SJunchao Zhang mm->kh4.create_spgemm_handle(spgemm_alg); 1266076ba34aSJunchao Zhang 12670e3ece09SJunchao Zhang // Bcast B's rows to form F, and overlap the communication 12680e3ece09SJunchao Zhang PetscIntKokkosViewHost map_h("map_h", bmpi->B->cmap->n); 12690e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1270076ba34aSJunchao Zhang 12710e3ece09SJunchao Zhang // A's diag * (B's diag + B's off-diag) 12720e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 12730e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 12740e3ece09SJunchao Zhang // KK spgemm_symbolic() only populates the result's row map, but not its columns. 12750e3ece09SJunchao Zhang // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 12760e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 12770e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 12780e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 12790e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C1)); 12800e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 12810e3ece09SJunchao Zhang #endif 1282076ba34aSJunchao Zhang 12830e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1284076ba34aSJunchao Zhang 12850e3ece09SJunchao Zhang // A's off-diag * (F's diag + F's off-diag) 12860e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12870e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12880e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12890e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12900e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 12910e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C3)); 12920e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C4)); 12930e3ece09SJunchao Zhang #endif 1294076ba34aSJunchao Zhang 12950e3ece09SJunchao Zhang // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 12960e3ece09SJunchao Zhang MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj("j", oldj.extent(0)); 12970e3ece09SJunchao Zhang PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 12980e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 12990e3ece09SJunchao Zhang oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 13000e3ece09SJunchao Zhang mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 13010e3ece09SJunchao Zhang 13020e3ece09SJunchao Zhang // C = (Cd, Co) = (C1+C3, C2+C4) 13030e3ece09SJunchao Zhang mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 13040e3ece09SJunchao Zhang mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 13050e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 13060e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 13070e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 13080e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 13093ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1310076ba34aSJunchao Zhang } 1311076ba34aSJunchao Zhang 13120e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1313d71ae5a4SJacob Faibussowitsch { 13140e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 13150e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 13160e3ece09SJunchao Zhang KokkosCsrMatrix Ad, Ao, Bd, Bo; 1317076ba34aSJunchao Zhang 1318076ba34aSJunchao Zhang PetscFunctionBegin; 13190e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 13200e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 13210e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 13220e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1323076ba34aSJunchao Zhang 13240e3ece09SJunchao Zhang // Bcast B's rows to form F, and overlap the communication 13250e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1326076ba34aSJunchao Zhang 13270e3ece09SJunchao Zhang // A's diag * (B's diag + B's off-diag) 13280e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 13290e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1330076ba34aSJunchao Zhang 13310e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1332076ba34aSJunchao Zhang 13330e3ece09SJunchao Zhang // A's off-diag * (F's diag + F's off-diag) 13340e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 13350e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 13360e3ece09SJunchao Zhang 13370e3ece09SJunchao Zhang // C = (Cd, Co) = (C1+C3, C2+C4) 13380e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 13390e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 13403ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1341076ba34aSJunchao Zhang } 1342076ba34aSJunchao Zhang 1343d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1344d71ae5a4SJacob Faibussowitsch { 13450e3ece09SJunchao Zhang Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 13460e3ece09SJunchao Zhang Mat_Product *product; 13470e3ece09SJunchao Zhang MatProductData_MPIAIJKokkos *pdata; 1348076ba34aSJunchao Zhang MatProductType ptype; 13490e3ece09SJunchao Zhang Mat A, B; 1350076ba34aSJunchao Zhang 1351076ba34aSJunchao Zhang PetscFunctionBegin; 13520e3ece09SJunchao Zhang MatCheckProduct(C, 1); // make sure C is a product 13530e3ece09SJunchao Zhang product = C->product; 13540e3ece09SJunchao Zhang pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1355076ba34aSJunchao Zhang ptype = product->type; 1356076ba34aSJunchao Zhang A = product->A; 1357076ba34aSJunchao Zhang B = product->B; 1358076ba34aSJunchao Zhang 13590e3ece09SJunchao Zhang // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 13600e3ece09SJunchao Zhang // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 13610e3ece09SJunchao Zhang // we still do numeric. 13620e3ece09SJunchao Zhang if (pdata->reusesym) { // numeric reuses results from symbolic 13630e3ece09SJunchao Zhang pdata->reusesym = PETSC_FALSE; 13643ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1365076ba34aSJunchao Zhang } 1366076ba34aSJunchao Zhang 1367076ba34aSJunchao Zhang if (ptype == MATPRODUCT_AB) { 13680e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1369076ba34aSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 13700e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 13710e3ece09SJunchao Zhang } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 13720e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 13730e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1374076ba34aSJunchao Zhang } 13750e3ece09SJunchao Zhang 13760e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 13770e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 13783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1379076ba34aSJunchao Zhang } 1380076ba34aSJunchao Zhang 1381d71ae5a4SJacob Faibussowitsch PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1382d71ae5a4SJacob Faibussowitsch { 1383076ba34aSJunchao Zhang Mat A, B; 13840e3ece09SJunchao Zhang Mat_Product *product; 1385076ba34aSJunchao Zhang MatProductType ptype; 13860e3ece09SJunchao Zhang MatProductData_MPIAIJKokkos *pdata; 1387076ba34aSJunchao Zhang MatMatStruct *mm = NULL; 13880e3ece09SJunchao Zhang PetscInt m, n, M, N; 13890e3ece09SJunchao Zhang Mat Cd, Co; 13900e3ece09SJunchao Zhang MPI_Comm comm; 1391076ba34aSJunchao Zhang 1392076ba34aSJunchao Zhang PetscFunctionBegin; 13930e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1394076ba34aSJunchao Zhang MatCheckProduct(C, 1); 13950e3ece09SJunchao Zhang product = C->product; 13960e3ece09SJunchao Zhang PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1397076ba34aSJunchao Zhang ptype = product->type; 1398076ba34aSJunchao Zhang A = product->A; 1399076ba34aSJunchao Zhang B = product->B; 1400076ba34aSJunchao Zhang 1401076ba34aSJunchao Zhang switch (ptype) { 14029371c9d4SSatish Balay case MATPRODUCT_AB: 14039371c9d4SSatish Balay m = A->rmap->n; 14049371c9d4SSatish Balay n = B->cmap->n; 14059371c9d4SSatish Balay M = A->rmap->N; 14069371c9d4SSatish Balay N = B->cmap->N; 14079371c9d4SSatish Balay break; 14089371c9d4SSatish Balay case MATPRODUCT_AtB: 14099371c9d4SSatish Balay m = A->cmap->n; 14109371c9d4SSatish Balay n = B->cmap->n; 14119371c9d4SSatish Balay M = A->cmap->N; 14129371c9d4SSatish Balay N = B->cmap->N; 14139371c9d4SSatish Balay break; 14149371c9d4SSatish Balay case MATPRODUCT_PtAP: 14159371c9d4SSatish Balay m = B->cmap->n; 14169371c9d4SSatish Balay n = B->cmap->n; 14179371c9d4SSatish Balay M = B->cmap->N; 14189371c9d4SSatish Balay N = B->cmap->N; 14199371c9d4SSatish Balay break; /* BtAB */ 1420d71ae5a4SJacob Faibussowitsch default: 14210e3ece09SJunchao Zhang SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1422076ba34aSJunchao Zhang } 1423076ba34aSJunchao Zhang 14249566063dSJacob Faibussowitsch PetscCall(MatSetSizes(C, m, n, M, N)); 14259566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(C->rmap)); 14269566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(C->cmap)); 14279566063dSJacob Faibussowitsch PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1428076ba34aSJunchao Zhang 14290e3ece09SJunchao Zhang pdata = new MatProductData_MPIAIJKokkos(); 14300e3ece09SJunchao Zhang pdata->reusesym = product->api_user; 1431076ba34aSJunchao Zhang 1432076ba34aSJunchao Zhang if (ptype == MATPRODUCT_AB) { 14330e3ece09SJunchao Zhang auto mmAB = new MatMatStruct_AB(); 14340e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 14350e3ece09SJunchao Zhang mm = pdata->mmAB = mmAB; 1436076ba34aSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 14370e3ece09SJunchao Zhang auto mmAtB = new MatMatStruct_AtB(); 14380e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 14390e3ece09SJunchao Zhang mm = pdata->mmAtB = mmAtB; 14400e3ece09SJunchao Zhang } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 14410e3ece09SJunchao Zhang Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 14420e3ece09SJunchao Zhang 14430e3ece09SJunchao Zhang auto mmAB = new MatMatStruct_AB(); 14440e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 14450e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 14460e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 14470e3ece09SJunchao Zhang pdata->mmAB = mmAB; 14480e3ece09SJunchao Zhang 14490e3ece09SJunchao Zhang m = A->rmap->n; // Z's layout 14500e3ece09SJunchao Zhang n = B->cmap->n; 14510e3ece09SJunchao Zhang M = A->rmap->N; 14520e3ece09SJunchao Zhang N = B->cmap->N; 14530e3ece09SJunchao Zhang PetscCall(MatCreate(comm, &Z)); 14540e3ece09SJunchao Zhang PetscCall(MatSetSizes(Z, m, n, M, N)); 14550e3ece09SJunchao Zhang PetscCall(PetscLayoutSetUp(Z->rmap)); 14560e3ece09SJunchao Zhang PetscCall(PetscLayoutSetUp(Z->cmap)); 14570e3ece09SJunchao Zhang PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 14580e3ece09SJunchao Zhang PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 14590e3ece09SJunchao Zhang 14600e3ece09SJunchao Zhang auto mmAtB = new MatMatStruct_AtB(); 14610e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 14620e3ece09SJunchao Zhang 14630e3ece09SJunchao Zhang pdata->Z = Z; // give ownership to pdata 14640e3ece09SJunchao Zhang mm = pdata->mmAtB = mmAtB; 1465076ba34aSJunchao Zhang } 14660e3ece09SJunchao Zhang 14670e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 14680e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 14690e3ece09SJunchao Zhang PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 14700e3ece09SJunchao Zhang 14710e3ece09SJunchao Zhang C->product->data = pdata; 1472076ba34aSJunchao Zhang C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1473076ba34aSJunchao Zhang C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 14743ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1475076ba34aSJunchao Zhang } 1476076ba34aSJunchao Zhang 1477d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1478d71ae5a4SJacob Faibussowitsch { 1479076ba34aSJunchao Zhang Mat_Product *product = mat->product; 1480076ba34aSJunchao Zhang PetscBool match = PETSC_FALSE; 1481076ba34aSJunchao Zhang PetscBool usecpu = PETSC_FALSE; 1482076ba34aSJunchao Zhang 1483076ba34aSJunchao Zhang PetscFunctionBegin; 1484076ba34aSJunchao Zhang MatCheckProduct(mat, 1); 148548a46eb9SPierre Jolivet if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1486076ba34aSJunchao Zhang if (match) { /* we can always fallback to the CPU if requested */ 1487076ba34aSJunchao Zhang switch (product->type) { 1488076ba34aSJunchao Zhang case MATPRODUCT_AB: 1489076ba34aSJunchao Zhang if (product->api_user) { 1490d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 14919566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1492d0609cedSBarry Smith PetscOptionsEnd(); 1493076ba34aSJunchao Zhang } else { 1494d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 14959566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1496d0609cedSBarry Smith PetscOptionsEnd(); 1497076ba34aSJunchao Zhang } 1498076ba34aSJunchao Zhang break; 1499076ba34aSJunchao Zhang case MATPRODUCT_AtB: 1500076ba34aSJunchao Zhang if (product->api_user) { 1501d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 15029566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1503d0609cedSBarry Smith PetscOptionsEnd(); 1504076ba34aSJunchao Zhang } else { 1505d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 15069566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1507d0609cedSBarry Smith PetscOptionsEnd(); 1508076ba34aSJunchao Zhang } 1509076ba34aSJunchao Zhang break; 1510076ba34aSJunchao Zhang case MATPRODUCT_PtAP: 1511076ba34aSJunchao Zhang if (product->api_user) { 1512d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 15139566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1514d0609cedSBarry Smith PetscOptionsEnd(); 1515076ba34aSJunchao Zhang } else { 1516d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 15179566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1518d0609cedSBarry Smith PetscOptionsEnd(); 1519076ba34aSJunchao Zhang } 1520076ba34aSJunchao Zhang break; 1521d71ae5a4SJacob Faibussowitsch default: 1522d71ae5a4SJacob Faibussowitsch break; 1523076ba34aSJunchao Zhang } 1524076ba34aSJunchao Zhang match = (PetscBool)!usecpu; 1525076ba34aSJunchao Zhang } 1526076ba34aSJunchao Zhang if (match) { 1527076ba34aSJunchao Zhang switch (product->type) { 1528076ba34aSJunchao Zhang case MATPRODUCT_AB: 1529076ba34aSJunchao Zhang case MATPRODUCT_AtB: 1530d71ae5a4SJacob Faibussowitsch case MATPRODUCT_PtAP: 1531d71ae5a4SJacob Faibussowitsch mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1532d71ae5a4SJacob Faibussowitsch break; 1533d71ae5a4SJacob Faibussowitsch default: 1534d71ae5a4SJacob Faibussowitsch break; 1535076ba34aSJunchao Zhang } 1536076ba34aSJunchao Zhang } 1537076ba34aSJunchao Zhang /* fallback to MPIAIJ ops */ 153848a46eb9SPierre Jolivet if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 15393ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1540076ba34aSJunchao Zhang } 1541076ba34aSJunchao Zhang 1542d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1543d71ae5a4SJacob Faibussowitsch { 1544394ed5ebSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 1545cbc6b225SStefano Zampini Mat_MPIAIJKokkos *mpikok; 154642550becSJunchao Zhang 154742550becSJunchao Zhang PetscFunctionBegin; 154830203840SJunchao Zhang PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1549cbc6b225SStefano Zampini mat->preallocated = PETSC_TRUE; 15509566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 15519566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 15529566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(mat)); 1553cbc6b225SStefano Zampini mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr); 1554cbc6b225SStefano Zampini delete mpikok; 1555394ed5ebSJunchao Zhang mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij); 15563ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 155742550becSJunchao Zhang } 155842550becSJunchao Zhang 1559d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1560d71ae5a4SJacob Faibussowitsch { 1561394ed5ebSJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 156242550becSJunchao Zhang Mat_MPIAIJKokkos *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr); 156342550becSJunchao Zhang Mat A = mpiaij->A, B = mpiaij->B; 1564158ec288SJunchao Zhang PetscCount Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2; 156542550becSJunchao Zhang MatScalarKokkosView Aa, Ba; 1566394ed5ebSJunchao Zhang MatScalarKokkosView v1; 156742550becSJunchao Zhang MatScalarKokkosView &vsend = mpikok->sendbuf_d; 156842550becSJunchao Zhang const MatScalarKokkosView &v2 = mpikok->recvbuf_d; 1569158ec288SJunchao Zhang const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d; 1570158ec288SJunchao Zhang const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d; 1571394ed5ebSJunchao Zhang const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d; 1572394ed5ebSJunchao Zhang const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d; 157342550becSJunchao Zhang PetscMemType memtype; 157442550becSJunchao Zhang 157542550becSJunchao Zhang PetscFunctionBegin; 15769566063dSJacob Faibussowitsch PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 157742550becSJunchao Zhang if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1578394ed5ebSJunchao Zhang v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n)); 157942550becSJunchao Zhang } else { 1580394ed5ebSJunchao Zhang v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */ 158142550becSJunchao Zhang } 158242550becSJunchao Zhang 158342550becSJunchao Zhang if (imode == INSERT_VALUES) { 15849566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 15859566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1586394ed5ebSJunchao Zhang } else { 15879566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 15889566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 158942550becSJunchao Zhang } 159042550becSJunchao Zhang 159142550becSJunchao Zhang /* Pack entries to be sent to remote */ 15929371c9d4SSatish Balay Kokkos::parallel_for( 15939371c9d4SSatish Balay vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 159442550becSJunchao Zhang 159542550becSJunchao Zhang /* Send remote entries to their owner and overlap the communication with local computation */ 15969566063dSJacob Faibussowitsch PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1597158ec288SJunchao Zhang /* Add local entries to A and B in one kernel */ 15989371c9d4SSatish Balay Kokkos::parallel_for( 15999371c9d4SSatish Balay Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) { 1600158ec288SJunchao Zhang PetscScalar sum = 0.0; 1601158ec288SJunchao Zhang if (i < Annz) { 1602158ec288SJunchao Zhang for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1603ac38520cSJunchao Zhang Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1604158ec288SJunchao Zhang } else { 1605158ec288SJunchao Zhang i -= Annz; 1606158ec288SJunchao Zhang for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1607ac38520cSJunchao Zhang Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1608158ec288SJunchao Zhang } 1609158ec288SJunchao Zhang }); 16109566063dSJacob Faibussowitsch PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 161142550becSJunchao Zhang 1612158ec288SJunchao Zhang /* Add received remote entries to A and B in one kernel */ 16139371c9d4SSatish Balay Kokkos::parallel_for( 16149371c9d4SSatish Balay Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) { 1615158ec288SJunchao Zhang if (i < Annz2) { 1616158ec288SJunchao Zhang for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1617158ec288SJunchao Zhang } else { 1618158ec288SJunchao Zhang i -= Annz2; 1619158ec288SJunchao Zhang for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1620158ec288SJunchao Zhang } 1621158ec288SJunchao Zhang }); 162242550becSJunchao Zhang 1623394ed5ebSJunchao Zhang if (imode == INSERT_VALUES) { 16249566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */ 16259566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1626394ed5ebSJunchao Zhang } else { 16279566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 16289566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1629394ed5ebSJunchao Zhang } 16303ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 163142550becSJunchao Zhang } 163242550becSJunchao Zhang 1633d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1634d71ae5a4SJacob Faibussowitsch { 163542550becSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 1636076ba34aSJunchao Zhang 1637076ba34aSJunchao Zhang PetscFunctionBegin; 16389566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 16399566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 16409566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 16419566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 164242550becSJunchao Zhang delete (Mat_MPIAIJKokkos *)mpiaij->spptr; 16439566063dSJacob Faibussowitsch PetscCall(MatDestroy_MPIAIJ(A)); 16443ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1645076ba34aSJunchao Zhang } 1646076ba34aSJunchao Zhang 1647d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1648d71ae5a4SJacob Faibussowitsch { 16498c3ff71bSJunchao Zhang Mat B; 1650076ba34aSJunchao Zhang Mat_MPIAIJ *a; 16518c3ff71bSJunchao Zhang 16528c3ff71bSJunchao Zhang PetscFunctionBegin; 16538c3ff71bSJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 16549566063dSJacob Faibussowitsch PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 16558c3ff71bSJunchao Zhang } else if (reuse == MAT_REUSE_MATRIX) { 16569566063dSJacob Faibussowitsch PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 16578c3ff71bSJunchao Zhang } 16588c3ff71bSJunchao Zhang B = *newmat; 16598c3ff71bSJunchao Zhang 16606f3d89d0SStefano Zampini B->boundtocpu = PETSC_FALSE; 16619566063dSJacob Faibussowitsch PetscCall(PetscFree(B->defaultvectype)); 16629566063dSJacob Faibussowitsch PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 16639566063dSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 16648c3ff71bSJunchao Zhang 1665076ba34aSJunchao Zhang a = static_cast<Mat_MPIAIJ *>(A->data); 16669566063dSJacob Faibussowitsch if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 16679566063dSJacob Faibussowitsch if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 16689566063dSJacob Faibussowitsch if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1669076ba34aSJunchao Zhang 16708c3ff71bSJunchao Zhang B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 16718c3ff71bSJunchao Zhang B->ops->mult = MatMult_MPIAIJKokkos; 16728c3ff71bSJunchao Zhang B->ops->multadd = MatMultAdd_MPIAIJKokkos; 16738c3ff71bSJunchao Zhang B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1674076ba34aSJunchao Zhang B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1675076ba34aSJunchao Zhang B->ops->destroy = MatDestroy_MPIAIJKokkos; 16768c3ff71bSJunchao Zhang 16779566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 16789566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 16799566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 16809566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 16813ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 16828c3ff71bSJunchao Zhang } 16833f3ba80aSJunchao Zhang /*MC 168411a5261eSBarry Smith MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 16858c3ff71bSJunchao Zhang 16863f3ba80aSJunchao Zhang A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types 16873f3ba80aSJunchao Zhang 16882ef1f0ffSBarry Smith Options Database Key: 16892ef1f0ffSBarry Smith . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 16903f3ba80aSJunchao Zhang 16913f3ba80aSJunchao Zhang Level: beginner 16923f3ba80aSJunchao Zhang 16932ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 16943f3ba80aSJunchao Zhang M*/ 1695d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1696d71ae5a4SJacob Faibussowitsch { 16978c3ff71bSJunchao Zhang PetscFunctionBegin; 16989566063dSJacob Faibussowitsch PetscCall(PetscKokkosInitializeCheck()); 16999566063dSJacob Faibussowitsch PetscCall(MatCreate_MPIAIJ(A)); 17009566063dSJacob Faibussowitsch PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 17013ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17028c3ff71bSJunchao Zhang } 17038c3ff71bSJunchao Zhang 17048c3ff71bSJunchao Zhang /*@C 170511a5261eSBarry Smith MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 17068c3ff71bSJunchao Zhang (the default parallel PETSc format). This matrix will ultimately pushed down 1707*20f4b53cSBarry Smith to Kokkos for calculations. 17088c3ff71bSJunchao Zhang 17098c3ff71bSJunchao Zhang Collective 17108c3ff71bSJunchao Zhang 17118c3ff71bSJunchao Zhang Input Parameters: 171211a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF` 1713*20f4b53cSBarry Smith . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1714*20f4b53cSBarry Smith This value should be the same as the local size used in creating the 1715*20f4b53cSBarry Smith y vector for the matrix-vector product y = Ax. 1716*20f4b53cSBarry Smith . n - This value should be the same as the local size used in creating the 1717*20f4b53cSBarry Smith x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1718*20f4b53cSBarry Smith calculated if N is given) For square matrices n is almost always `m`. 1719*20f4b53cSBarry Smith . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1720*20f4b53cSBarry Smith . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1721*20f4b53cSBarry Smith . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1722*20f4b53cSBarry Smith (same value is used for all local rows) 1723*20f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the 1724*20f4b53cSBarry Smith DIAGONAL portion of the local submatrix (possibly different for each row) 1725*20f4b53cSBarry Smith or `NULL`, if `d_nz` is used to specify the nonzero structure. 1726*20f4b53cSBarry Smith The size of this array is equal to the number of local rows, i.e `m`. 1727*20f4b53cSBarry Smith For matrices you plan to factor you must leave room for the diagonal entry and 1728*20f4b53cSBarry Smith put in the entry even if it is zero. 1729*20f4b53cSBarry Smith . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1730*20f4b53cSBarry Smith submatrix (same value is used for all local rows). 1731*20f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the 1732*20f4b53cSBarry Smith OFF-DIAGONAL portion of the local submatrix (possibly different for 1733*20f4b53cSBarry Smith each row) or `NULL`, if `o_nz` is used to specify the nonzero 1734*20f4b53cSBarry Smith structure. The size of this array is equal to the number 1735*20f4b53cSBarry Smith of local rows, i.e `m`. 17368c3ff71bSJunchao Zhang 17378c3ff71bSJunchao Zhang Output Parameter: 17388c3ff71bSJunchao Zhang . A - the matrix 17398c3ff71bSJunchao Zhang 17402ef1f0ffSBarry Smith Level: intermediate 17412ef1f0ffSBarry Smith 17422ef1f0ffSBarry Smith Notes: 174311a5261eSBarry Smith It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 17448c3ff71bSJunchao Zhang MatXXXXSetPreallocation() paradigm instead of this routine directly. 174511a5261eSBarry Smith [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 17468c3ff71bSJunchao Zhang 1747667f096bSBarry Smith The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 17488c3ff71bSJunchao Zhang storage. That is, the stored row and column indices can begin at 17492ef1f0ffSBarry Smith either one (as in Fortran) or zero. 17508c3ff71bSJunchao Zhang 17512ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 17522ef1f0ffSBarry Smith `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 17538c3ff71bSJunchao Zhang @*/ 1754d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1755d71ae5a4SJacob Faibussowitsch { 17568c3ff71bSJunchao Zhang PetscMPIInt size; 17578c3ff71bSJunchao Zhang 17588c3ff71bSJunchao Zhang PetscFunctionBegin; 17599566063dSJacob Faibussowitsch PetscCall(MatCreate(comm, A)); 17609566063dSJacob Faibussowitsch PetscCall(MatSetSizes(*A, m, n, M, N)); 17619566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size)); 17628c3ff71bSJunchao Zhang if (size > 1) { 17639566063dSJacob Faibussowitsch PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 17649566063dSJacob Faibussowitsch PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 17658c3ff71bSJunchao Zhang } else { 17669566063dSJacob Faibussowitsch PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 17679566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 17688c3ff71bSJunchao Zhang } 17693ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17708c3ff71bSJunchao Zhang } 17718c3ff71bSJunchao Zhang 1772a587d139SMark // get GPU pointer to stripped down Mat. For both Seq and MPI Mat. 1773d71ae5a4SJacob Faibussowitsch PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B) 1774d71ae5a4SJacob Faibussowitsch { 1775a587d139SMark PetscMPIInt size, rank; 1776a587d139SMark MPI_Comm comm; 1777042217e8SBarry Smith PetscSplitCSRDataStructure d_mat = NULL; 1778a587d139SMark 1779a587d139SMark PetscFunctionBegin; 17809566063dSJacob Faibussowitsch PetscCall(PetscObjectGetComm((PetscObject)A, &comm)); 17819566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size)); 17829566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_rank(comm, &rank)); 1783a587d139SMark if (size == 1) { 17849566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat)); 17859566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */ 1786a587d139SMark } else { 1787a587d139SMark Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data; 17889566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat)); 17899566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosModifyDevice(aij->A)); 17909566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosModifyDevice(aij->B)); 17912c71b3e2SJacob Faibussowitsch PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)"); 1792a587d139SMark } 1793a587d139SMark // act like MatSetValues because not called on host 1794a587d139SMark if (A->assembled) { 179548a46eb9SPierre Jolivet if (A->was_assembled) PetscCall(PetscInfo(A, "Assemble more than once already\n")); 1796a587d139SMark A->was_assembled = PETSC_TRUE; // this is done (lazy) in MatAssemble but we are not calling it anymore - done in AIJ AssemblyEnd, need here? 1797a587d139SMark } else { 17989566063dSJacob Faibussowitsch PetscCall(PetscInfo(A, "Warning !assemble ??? assembled=%" PetscInt_FMT "\n", A->assembled)); 1799a587d139SMark } 1800a587d139SMark if (!d_mat) { 1801042217e8SBarry Smith struct _n_SplitCSRMat h_mat; /* host container */ 1802a587d139SMark Mat_SeqAIJKokkos *aijkokA; 1803a587d139SMark Mat_SeqAIJ *jaca; 1804a587d139SMark PetscInt n = A->rmap->n, nnz; 1805a587d139SMark Mat Amat; 1806042217e8SBarry Smith PetscInt *colmap; 1807042217e8SBarry Smith 1808042217e8SBarry Smith /* create and copy h_mat */ 180949b994a9SMark Adams h_mat.M = A->cmap->N; // use for debug build 18109566063dSJacob Faibussowitsch PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n")); 1811a587d139SMark if (size == 1) { 1812a587d139SMark Amat = A; 1813a587d139SMark jaca = (Mat_SeqAIJ *)A->data; 18149371c9d4SSatish Balay h_mat.rstart = 0; 18159371c9d4SSatish Balay h_mat.rend = A->rmap->n; 18169371c9d4SSatish Balay h_mat.cstart = 0; 18179371c9d4SSatish Balay h_mat.cend = A->cmap->n; 1818a587d139SMark h_mat.offdiag.i = h_mat.offdiag.j = NULL; 1819a587d139SMark h_mat.offdiag.a = NULL; 1820a587d139SMark aijkokA = static_cast<Mat_SeqAIJKokkos *>(A->spptr); 1821a587d139SMark } else { 1822a587d139SMark Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data; 1823a587d139SMark Mat_SeqAIJ *jacb = (Mat_SeqAIJ *)aij->B->data; 1824a587d139SMark PetscInt ii; 1825a587d139SMark Mat_SeqAIJKokkos *aijkokB; 1826042217e8SBarry Smith 1827a587d139SMark Amat = aij->A; 1828a587d139SMark aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr); 1829a587d139SMark aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr); 1830a587d139SMark jaca = (Mat_SeqAIJ *)aij->A->data; 183108401ef6SPierre Jolivet PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray"); 183208401ef6SPierre Jolivet PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n"); 1833a587d139SMark aij->donotstash = PETSC_TRUE; 1834a587d139SMark aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE; 1835a5b23f4aSJose E. Roman jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly 18369566063dSJacob Faibussowitsch PetscCall(PetscCalloc1(A->cmap->N, &colmap)); 1837042217e8SBarry Smith for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1; 1838a587d139SMark // allocate B copy data 18399371c9d4SSatish Balay h_mat.rstart = A->rmap->rstart; 18409371c9d4SSatish Balay h_mat.rend = A->rmap->rend; 18419371c9d4SSatish Balay h_mat.cstart = A->cmap->rstart; 18429371c9d4SSatish Balay h_mat.cend = A->cmap->rend; 1843a587d139SMark nnz = jacb->i[n]; 1844a587d139SMark if (jacb->compressedrow.use) { 1845a587d139SMark const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1); 1846300d22a6SJunchao Zhang aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k)); 1847300d22a6SJunchao Zhang Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k); 1848300d22a6SJunchao Zhang h_mat.offdiag.i = aijkokB->i_uncompressed_d.data(); 1849a587d139SMark } else { 185099551766SMark Adams h_mat.offdiag.i = aijkokB->i_device_data(); 1851a587d139SMark } 185299551766SMark Adams h_mat.offdiag.j = aijkokB->j_device_data(); 1853076ba34aSJunchao Zhang h_mat.offdiag.a = aijkokB->a_device_data(); 1854a587d139SMark { 1855042217e8SBarry Smith Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N); 1856300d22a6SJunchao Zhang aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k)); 1857300d22a6SJunchao Zhang Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k); 1858300d22a6SJunchao Zhang h_mat.colmap = aijkokB->colmap_d.data(); 18599566063dSJacob Faibussowitsch PetscCall(PetscFree(colmap)); 1860a587d139SMark } 1861a587d139SMark h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries; 1862a587d139SMark h_mat.offdiag.n = n; 1863a587d139SMark } 1864a587d139SMark // allocate A copy data 1865a587d139SMark nnz = jaca->i[n]; 1866a587d139SMark h_mat.diag.n = n; 1867a587d139SMark h_mat.diag.ignorezeroentries = jaca->ignorezeroentries; 18689566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank)); 1869d5b43468SJose E. Roman PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)"); 187099551766SMark Adams h_mat.diag.i = aijkokA->i_device_data(); 187199551766SMark Adams h_mat.diag.j = aijkokA->j_device_data(); 1872076ba34aSJunchao Zhang h_mat.diag.a = aijkokA->a_device_data(); 1873da81f932SPierre Jolivet // copy pointers and metadata to device 18749566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat)); 18759566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat)); 18769566063dSJacob Faibussowitsch PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz)); 1877a587d139SMark } 1878a587d139SMark *B = d_mat; // return it, set it in Mat, and set it up 1879a587d139SMark A->assembled = PETSC_FALSE; // ready to write with matsetvalues - this done (lazy) in normal MatSetValues 18803ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1881a587d139SMark } 1882076ba34aSJunchao Zhang 1883d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask) 1884d71ae5a4SJacob Faibussowitsch { 1885076ba34aSJunchao Zhang Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr); 1886076ba34aSJunchao Zhang 1887076ba34aSJunchao Zhang PetscFunctionBegin; 1888076ba34aSJunchao Zhang if (!aijkok) *mask = "AIJKOK_UNALLOCATED"; 1889076ba34aSJunchao Zhang else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU"; 1890076ba34aSJunchao Zhang else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU"; 1891076ba34aSJunchao Zhang else *mask = "PETSC_OFFLOAD_BOTH"; 18923ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1893076ba34aSJunchao Zhang } 1894076ba34aSJunchao Zhang 1895d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A) 1896d71ae5a4SJacob Faibussowitsch { 1897076ba34aSJunchao Zhang PetscMPIInt size; 1898076ba34aSJunchao Zhang Mat Ad, Ao; 1899076ba34aSJunchao Zhang const char *amask, *bmask; 1900076ba34aSJunchao Zhang 1901076ba34aSJunchao Zhang PetscFunctionBegin; 19029566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size)); 1903076ba34aSJunchao Zhang 1904076ba34aSJunchao Zhang if (size == 1) { 19059566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask)); 19069566063dSJacob Faibussowitsch PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask)); 1907076ba34aSJunchao Zhang } else { 1908076ba34aSJunchao Zhang Ad = ((Mat_MPIAIJ *)A->data)->A; 1909076ba34aSJunchao Zhang Ao = ((Mat_MPIAIJ *)A->data)->B; 19109566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask)); 19119566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask)); 19129566063dSJacob Faibussowitsch PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask)); 1913076ba34aSJunchao Zhang } 19143ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1915076ba34aSJunchao Zhang } 1916