1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp> 211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp> 3f0e6e2d1SJunchao Zhang #include <petscpkg_version.h> 442550becSJunchao Zhang #include <petsc/private/sfimpl.h> 592896123SJunchao Zhang #include <petsc/private/kokkosimpl.hpp> 62c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 78c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h> 8076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp> 90e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp> 1011d22bbfSJunchao Zhang 1166976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 12d71ae5a4SJacob Faibussowitsch { 1330203840SJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 148c3ff71bSJunchao Zhang 158c3ff71bSJunchao Zhang PetscFunctionBegin; 169566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 1730203840SJunchao Zhang /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 1830203840SJunchao Zhang Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 1930203840SJunchao Zhang */ 2030203840SJunchao Zhang if (mode == MAT_FINAL_ASSEMBLY) { 2192896123SJunchao Zhang PetscScalarKokkosView v; 2292896123SJunchao Zhang 2330203840SJunchao Zhang PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 2430203840SJunchao Zhang PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 2592896123SJunchao Zhang PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device 2692896123SJunchao Zhang PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device 2792896123SJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v)); 2830203840SJunchao Zhang } 293ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 308c3ff71bSJunchao Zhang } 318c3ff71bSJunchao Zhang 3266976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 33d71ae5a4SJacob Faibussowitsch { 342cdb1aeaSJunchao Zhang Mat_MPIAIJ *mpiaij; 358c3ff71bSJunchao Zhang 368c3ff71bSJunchao Zhang PetscFunctionBegin; 372cdb1aeaSJunchao Zhang // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things 382cdb1aeaSJunchao Zhang PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz)); 392cdb1aeaSJunchao Zhang mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 402cdb1aeaSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A)); 412cdb1aeaSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B)); 423ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 438c3ff71bSJunchao Zhang } 448c3ff71bSJunchao Zhang 4566976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 46d71ae5a4SJacob Faibussowitsch { 478c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 488c3ff71bSJunchao Zhang PetscInt nt; 498c3ff71bSJunchao Zhang 508c3ff71bSJunchao Zhang PetscFunctionBegin; 519566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 5208401ef6SPierre Jolivet PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 539566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 549566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 559566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 569566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 573ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 588c3ff71bSJunchao Zhang } 598c3ff71bSJunchao Zhang 6066976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 61d71ae5a4SJacob Faibussowitsch { 628c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 638c3ff71bSJunchao Zhang PetscInt nt; 648c3ff71bSJunchao Zhang 658c3ff71bSJunchao Zhang PetscFunctionBegin; 669566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 6708401ef6SPierre Jolivet PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 689566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 699566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 709566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 719566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 723ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 738c3ff71bSJunchao Zhang } 748c3ff71bSJunchao Zhang 7566976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 76d71ae5a4SJacob Faibussowitsch { 778c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 788c3ff71bSJunchao Zhang PetscInt nt; 798c3ff71bSJunchao Zhang 808c3ff71bSJunchao Zhang PetscFunctionBegin; 819566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 8208401ef6SPierre Jolivet PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 839566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 849566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 859566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 869566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 873ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 888c3ff71bSJunchao Zhang } 898c3ff71bSJunchao Zhang 90076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 91076ba34aSJunchao Zhang A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 92076ba34aSJunchao Zhang C still uses local column ids. Their corresponding global column ids are returned in glob. 93076ba34aSJunchao Zhang */ 9466976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 95d71ae5a4SJacob Faibussowitsch { 96076ba34aSJunchao Zhang Mat Ad, Ao; 97076ba34aSJunchao Zhang const PetscInt *cmap; 98076ba34aSJunchao Zhang 99076ba34aSJunchao Zhang PetscFunctionBegin; 1009566063dSJacob Faibussowitsch PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 1019566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 102076ba34aSJunchao Zhang if (glob) { 103076ba34aSJunchao Zhang PetscInt cst, i, dn, on, *gidx; 1049566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 1059566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(Ao, NULL, &on)); 1069566063dSJacob Faibussowitsch PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 1079566063dSJacob Faibussowitsch PetscCall(PetscMalloc1(dn + on, &gidx)); 108076ba34aSJunchao Zhang for (i = 0; i < dn; i++) gidx[i] = cst + i; 109076ba34aSJunchao Zhang for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 1109566063dSJacob Faibussowitsch PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 111076ba34aSJunchao Zhang } 1123ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 113076ba34aSJunchao Zhang } 114076ba34aSJunchao Zhang 1150e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 116076ba34aSJunchao Zhang struct MatMatStruct { 1170e3ece09SJunchao Zhang PetscInt n, *garray; // C's garray and its size. 1180e3ece09SJunchao Zhang KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 1190e3ece09SJunchao Zhang KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 1200e3ece09SJunchao Zhang KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 1210e3ece09SJunchao Zhang PetscIntKokkosView E_NzLeft; 1220e3ece09SJunchao Zhang PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 1230e3ece09SJunchao Zhang MatScalarKokkosView rootBuf, leafBuf; 1240e3ece09SJunchao Zhang KokkosCsrMatrix Fd, Fo; // F in split form 1250e3ece09SJunchao Zhang 1260e3ece09SJunchao Zhang KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 1270e3ece09SJunchao Zhang KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 1280e3ece09SJunchao Zhang KernelHandle kh3; // compute C3 1290e3ece09SJunchao Zhang KernelHandle kh4; // compute C4 1300e3ece09SJunchao Zhang 131aaa8cc7dSPierre Jolivet PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 1320e3ece09SJunchao Zhang PetscInt E_VectorLength; 1330e3ece09SJunchao Zhang PetscInt E_RowsPerTeam; 1340e3ece09SJunchao Zhang PetscInt F_TeamSize; 1350e3ece09SJunchao Zhang PetscInt F_VectorLength; 1360e3ece09SJunchao Zhang PetscInt F_RowsPerTeam; 137076ba34aSJunchao Zhang 138d71ae5a4SJacob Faibussowitsch ~MatMatStruct() 139d71ae5a4SJacob Faibussowitsch { 1403ba16761SJacob Faibussowitsch PetscFunctionBegin; 1413ba16761SJacob Faibussowitsch PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 1423ba16761SJacob Faibussowitsch PetscFunctionReturnVoid(); 143076ba34aSJunchao Zhang } 144076ba34aSJunchao Zhang }; 145076ba34aSJunchao Zhang 146076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct { 1470e3ece09SJunchao Zhang PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 1480e3ece09SJunchao Zhang PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 1490e3ece09SJunchao Zhang PetscIntKokkosView rowoffset; 150076ba34aSJunchao Zhang }; 151076ba34aSJunchao Zhang 152076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct { 1530e3ece09SJunchao Zhang MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 1540e3ece09SJunchao Zhang MatColIdxKokkosView Fdjperm; 1550e3ece09SJunchao Zhang MatColIdxKokkosView Fojmap; 1560e3ece09SJunchao Zhang MatColIdxKokkosView Fojperm; 157076ba34aSJunchao Zhang }; 158076ba34aSJunchao Zhang 1599371c9d4SSatish Balay struct MatProductData_MPIAIJKokkos { 1603ba16761SJacob Faibussowitsch MatMatStruct_AB *mmAB = nullptr; 1613ba16761SJacob Faibussowitsch MatMatStruct_AtB *mmAtB = nullptr; 1623ba16761SJacob Faibussowitsch PetscBool reusesym = PETSC_FALSE; 1630e3ece09SJunchao Zhang Mat Z = nullptr; // store Z=AB in computing BtAB 164076ba34aSJunchao Zhang 165d71ae5a4SJacob Faibussowitsch ~MatProductData_MPIAIJKokkos() 166d71ae5a4SJacob Faibussowitsch { 167076ba34aSJunchao Zhang delete mmAB; 168076ba34aSJunchao Zhang delete mmAtB; 1690e3ece09SJunchao Zhang PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 170076ba34aSJunchao Zhang } 171076ba34aSJunchao Zhang }; 172076ba34aSJunchao Zhang 173d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 174d71ae5a4SJacob Faibussowitsch { 175076ba34aSJunchao Zhang PetscFunctionBegin; 1769566063dSJacob Faibussowitsch PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 1773ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 178076ba34aSJunchao Zhang } 179076ba34aSJunchao Zhang 180076ba34aSJunchao Zhang /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 181076ba34aSJunchao Zhang It is similar to MatCreateMPIAIJWithSplitArrays. 182076ba34aSJunchao Zhang 183076ba34aSJunchao Zhang Input Parameters: 184076ba34aSJunchao Zhang + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 185076ba34aSJunchao Zhang . A - the diag matrix using local col ids 186076ba34aSJunchao Zhang - B - the offdiag matrix using global col ids 187076ba34aSJunchao Zhang 1882fe279fdSBarry Smith Output Parameter: 189076ba34aSJunchao Zhang . mat - the updated MATMPIAIJKOKKOS matrix 190076ba34aSJunchao Zhang */ 1910e3ece09SJunchao Zhang static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 192d71ae5a4SJacob Faibussowitsch { 193076ba34aSJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 194076ba34aSJunchao Zhang PetscInt m, n, M, N, Am, An, Bm, Bn; 195076ba34aSJunchao Zhang 196076ba34aSJunchao Zhang PetscFunctionBegin; 1979566063dSJacob Faibussowitsch PetscCall(MatGetSize(mat, &M, &N)); 1989566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(mat, &m, &n)); 1999566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(A, &Am, &An)); 2009566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 201076ba34aSJunchao Zhang 202aed4548fSBarry Smith PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 20308401ef6SPierre Jolivet PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 2040e3ece09SJunchao Zhang // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 20508401ef6SPierre Jolivet PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 206076ba34aSJunchao Zhang mpiaij->A = A; 207076ba34aSJunchao Zhang mpiaij->B = B; 2080e3ece09SJunchao Zhang mpiaij->garray = garray; 209076ba34aSJunchao Zhang 210076ba34aSJunchao Zhang mat->preallocated = PETSC_TRUE; 211076ba34aSJunchao Zhang mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 212076ba34aSJunchao Zhang 2139566063dSJacob Faibussowitsch PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 2149566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 215076ba34aSJunchao Zhang /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and 216076ba34aSJunchao Zhang also gets mpiaij->B compacted, with its col ids and size reduced 217076ba34aSJunchao Zhang */ 2189566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 2199566063dSJacob Faibussowitsch PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 2209566063dSJacob Faibussowitsch PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 2213ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 222076ba34aSJunchao Zhang } 223076ba34aSJunchao Zhang 2240e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 2250e3ece09SJunchao Zhang // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 2260e3ece09SJunchao Zhang template <class ExecutionSpace> 2270e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 228d71ae5a4SJacob Faibussowitsch { 229*1aa660a0SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1) 230*1aa660a0SJunchao Zhang constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>(); 231*1aa660a0SJunchao Zhang #else 232*1aa660a0SJunchao Zhang constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>; 233*1aa660a0SJunchao Zhang #endif 2340e3ece09SJunchao Zhang Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 235076ba34aSJunchao Zhang 236076ba34aSJunchao Zhang PetscFunctionBegin; 2370e3ece09SJunchao Zhang PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 238076ba34aSJunchao Zhang 2390e3ece09SJunchao Zhang if (nnz_per_row < 1) nnz_per_row = 1; 240076ba34aSJunchao Zhang 2410e3ece09SJunchao Zhang int max_vector_length = teamPolicy.vector_length_max(); 242076ba34aSJunchao Zhang 2430e3ece09SJunchao Zhang if (vector_length < 1) { 2440e3ece09SJunchao Zhang vector_length = 1; 2450e3ece09SJunchao Zhang while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 246076ba34aSJunchao Zhang } 247076ba34aSJunchao Zhang 2480e3ece09SJunchao Zhang // Determine rows per thread 2490e3ece09SJunchao Zhang if (rows_per_thread < 1) { 250*1aa660a0SJunchao Zhang if (is_gpu_exec_space) rows_per_thread = 1; 2510e3ece09SJunchao Zhang else { 2520e3ece09SJunchao Zhang if (nnz_per_row < 20 && nnz > 5000000) { 2530e3ece09SJunchao Zhang rows_per_thread = 256; 2540e3ece09SJunchao Zhang } else rows_per_thread = 64; 255076ba34aSJunchao Zhang } 256076ba34aSJunchao Zhang } 257076ba34aSJunchao Zhang 2580e3ece09SJunchao Zhang if (team_size < 1) { 259*1aa660a0SJunchao Zhang if (is_gpu_exec_space) { 2600e3ece09SJunchao Zhang team_size = 256 / vector_length; 261076ba34aSJunchao Zhang } else { 2620e3ece09SJunchao Zhang team_size = 1; 2630e3ece09SJunchao Zhang } 264076ba34aSJunchao Zhang } 265076ba34aSJunchao Zhang 2660e3ece09SJunchao Zhang rows_per_team = rows_per_thread * team_size; 267076ba34aSJunchao Zhang 2680e3ece09SJunchao Zhang if (rows_per_team < 0) { 2690e3ece09SJunchao Zhang PetscInt nnz_per_team = 4096; 2700e3ece09SJunchao Zhang PetscInt conc = ExecutionSpace().concurrency(); 2710e3ece09SJunchao Zhang while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 2720e3ece09SJunchao Zhang rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 2730e3ece09SJunchao Zhang } 2743ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 275076ba34aSJunchao Zhang } 276076ba34aSJunchao Zhang 2770e3ece09SJunchao Zhang /* 2780e3ece09SJunchao Zhang Reduce two sets of global indices into local ones 279076ba34aSJunchao Zhang 280076ba34aSJunchao Zhang Input Parameters: 2810e3ece09SJunchao Zhang + n1 - size of garray1[], the first set 2820e3ece09SJunchao Zhang . garray1[n1] - a sorted global index array (without duplicates) 2830e3ece09SJunchao Zhang . m - size of indices[], the second set 2840e3ece09SJunchao Zhang - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 285076ba34aSJunchao Zhang 286076ba34aSJunchao Zhang Output Parameters: 2870e3ece09SJunchao Zhang + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 2880e3ece09SJunchao Zhang . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 2890e3ece09SJunchao Zhang . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 2900e3ece09SJunchao Zhang - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 291076ba34aSJunchao Zhang 2920e3ece09SJunchao Zhang Example, say 2930e3ece09SJunchao Zhang n1 = 5 2940e3ece09SJunchao Zhang garray1[5] = {1, 4, 7, 8, 10} 2950e3ece09SJunchao Zhang m = 4 2960e3ece09SJunchao Zhang indices[4] = {2, 4, 8, 9} 29711a5261eSBarry Smith 2980e3ece09SJunchao Zhang Combining them together, we have 7 global indices in garray2[] 2990e3ece09SJunchao Zhang n2 = 7 3000e3ece09SJunchao Zhang garray2[7] = {1, 2, 4, 7, 8, 9, 10} 3010e3ece09SJunchao Zhang 3020e3ece09SJunchao Zhang And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 3030e3ece09SJunchao Zhang map[5] = {0, 2, 3, 4, 6} 3040e3ece09SJunchao Zhang 3050e3ece09SJunchao Zhang On output, indices[] is updated with local indices 3060e3ece09SJunchao Zhang indices[4] = {1, 2, 4, 5} 307076ba34aSJunchao Zhang */ 3080e3ece09SJunchao Zhang static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 309d71ae5a4SJacob Faibussowitsch { 3100e3ece09SJunchao Zhang PetscHMapI g2l = nullptr; 3110e3ece09SJunchao Zhang PetscHashIter iter; 3120e3ece09SJunchao Zhang PetscInt tot, key, val; // total unique global indices. key is global id; val is local id 3130e3ece09SJunchao Zhang PetscInt n2, *garray2; 314076ba34aSJunchao Zhang 315076ba34aSJunchao Zhang PetscFunctionBegin; 3160e3ece09SJunchao Zhang tot = 0; 3170e3ece09SJunchao Zhang PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 3180e3ece09SJunchao Zhang for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 3190e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 3200e3ece09SJunchao Zhang if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 321076ba34aSJunchao Zhang } 322076ba34aSJunchao Zhang 3230e3ece09SJunchao Zhang for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 3240e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 3250e3ece09SJunchao Zhang if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 326076ba34aSJunchao Zhang } 327076ba34aSJunchao Zhang 3280e3ece09SJunchao Zhang // Pull out (unique) globals in the hash table and put them in garray2[] 3290e3ece09SJunchao Zhang n2 = tot; 3300e3ece09SJunchao Zhang PetscCall(PetscMalloc1(n2, &garray2)); 3310e3ece09SJunchao Zhang tot = 0; 3320e3ece09SJunchao Zhang PetscHashIterBegin(g2l, iter); 3330e3ece09SJunchao Zhang while (!PetscHashIterAtEnd(g2l, iter)) { 3340e3ece09SJunchao Zhang PetscHashIterGetKey(g2l, iter, key); 3350e3ece09SJunchao Zhang PetscHashIterNext(g2l, iter); 3360e3ece09SJunchao Zhang garray2[tot++] = key; 337076ba34aSJunchao Zhang } 338076ba34aSJunchao Zhang 3390e3ece09SJunchao Zhang // Sort garray2[] and then map them to local indices starting from 0 3400e3ece09SJunchao Zhang PetscCall(PetscSortInt(n2, garray2)); 3410e3ece09SJunchao Zhang PetscCall(PetscHMapIClear(g2l)); 3420e3ece09SJunchao Zhang for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 343f0e6e2d1SJunchao Zhang 3440e3ece09SJunchao Zhang // Rewrite indices[] with local indices 345f0e6e2d1SJunchao Zhang for (PetscInt i = 0; i < m; i++) { 3460e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 3470e3ece09SJunchao Zhang PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 3480e3ece09SJunchao Zhang indices[i] = val; 3490e3ece09SJunchao Zhang } 3500e3ece09SJunchao Zhang // Record the map that maps garray1[i] to garray2[map[i]] 3510e3ece09SJunchao Zhang for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 3520e3ece09SJunchao Zhang PetscCall(PetscHMapIDestroy(&g2l)); 3530e3ece09SJunchao Zhang *n2_ = n2; 3540e3ece09SJunchao Zhang *garray2_ = garray2; 3550e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3560e3ece09SJunchao Zhang } 357f0e6e2d1SJunchao Zhang 3580e3ece09SJunchao Zhang /* 3590e3ece09SJunchao Zhang MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 3600e3ece09SJunchao Zhang 3610e3ece09SJunchao Zhang It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 3620e3ece09SJunchao Zhang 3630e3ece09SJunchao Zhang Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 3640e3ece09SJunchao Zhang In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 3650e3ece09SJunchao Zhang 3660e3ece09SJunchao Zhang Input Parameters: 3670e3ece09SJunchao Zhang + comm - MPI communicator of E 3680e3ece09SJunchao Zhang . A - diag block of E, using local column indices 3690e3ece09SJunchao Zhang . B - off-diag block of E, using local column indices 3700e3ece09SJunchao Zhang . cstart - (global) start column of Ed 3710e3ece09SJunchao Zhang . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend) 3720e3ece09SJunchao Zhang . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size. 3730e3ece09SJunchao Zhang . ownerSF - the SF specifies ownership (root) of rows in E 3740e3ece09SJunchao Zhang . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 3750e3ece09SJunchao Zhang - mm - to stash intermediate data structures for reuse 3760e3ece09SJunchao Zhang 3770e3ece09SJunchao Zhang Output Parameters: 3780e3ece09SJunchao Zhang + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices(). 3790e3ece09SJunchao Zhang - mm - contains various info, such as garray2[], F (Fd, Fo) etc. 3800e3ece09SJunchao Zhang 3810e3ece09SJunchao Zhang Notes: 3820e3ece09SJunchao Zhang When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant. 3830e3ece09SJunchao Zhang 3840e3ece09SJunchao Zhang */ 3850e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 3860e3ece09SJunchao Zhang { 3870e3ece09SJunchao Zhang PetscFunctionBegin; 3880e3ece09SJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 3890e3ece09SJunchao Zhang PetscInt Em = A.numRows(), Fm; 3900e3ece09SJunchao Zhang PetscInt n1 = B.numCols(); 3910e3ece09SJunchao Zhang 3920e3ece09SJunchao Zhang PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 3930e3ece09SJunchao Zhang 3940e3ece09SJunchao Zhang // Do the analysis on host 3950e3ece09SJunchao Zhang auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 3960e3ece09SJunchao Zhang auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 3970e3ece09SJunchao Zhang auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 3980e3ece09SJunchao Zhang auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 3990e3ece09SJunchao Zhang const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 4000e3ece09SJunchao Zhang const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 4010e3ece09SJunchao Zhang 4020e3ece09SJunchao Zhang // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 4037b8d4ba6SJunchao Zhang PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 4040e3ece09SJunchao Zhang PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 4050e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) { 4060e3ece09SJunchao Zhang const PetscInt *first, *last, *it; 4070e3ece09SJunchao Zhang PetscInt count, step; 4080e3ece09SJunchao Zhang // std::lower_bound(first,last,cstart), but need to use global column indices 4090e3ece09SJunchao Zhang first = Bj + Bi[i]; 4100e3ece09SJunchao Zhang last = Bj + Bi[i + 1]; 411f0e6e2d1SJunchao Zhang count = last - first; 412f0e6e2d1SJunchao Zhang while (count > 0) { 413f0e6e2d1SJunchao Zhang it = first; 414f0e6e2d1SJunchao Zhang step = count / 2; 415f0e6e2d1SJunchao Zhang it += step; 4160e3ece09SJunchao Zhang if (garray1[*it] < cstart) { // map local to global 417f0e6e2d1SJunchao Zhang first = ++it; 418f0e6e2d1SJunchao Zhang count -= step + 1; 419f0e6e2d1SJunchao Zhang } else count = step; 420f0e6e2d1SJunchao Zhang } 4210e3ece09SJunchao Zhang E_NzLeft[i] = first - (Bj + Bi[i]); 4220e3ece09SJunchao Zhang E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 423f0e6e2d1SJunchao Zhang } 424f0e6e2d1SJunchao Zhang 4250e3ece09SJunchao Zhang // Get length of rows (i.e., sizes of leaves) that contribute to my roots 4260e3ece09SJunchao Zhang const PetscMPIInt *iranks, *ranks; 4270e3ece09SJunchao Zhang const PetscInt *ioffset, *irootloc, *roffset, *rmine; 428c09cee04SJames Wright PetscMPIInt niranks, nranks; 4290e3ece09SJunchao Zhang MPI_Request *reqs; 4300e3ece09SJunchao Zhang PetscMPIInt tag; 4310e3ece09SJunchao Zhang PetscSF reduceSF; 4320e3ece09SJunchao Zhang PetscInt *sdisp, *rdisp; 433f0e6e2d1SJunchao Zhang 4340e3ece09SJunchao Zhang PetscCall(PetscCommGetNewTag(comm, &tag)); 4350e3ece09SJunchao Zhang PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 4360e3ece09SJunchao Zhang PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 437f0e6e2d1SJunchao Zhang 4380e3ece09SJunchao Zhang // Find out length of each row I will receive. Even for the same row index, when they are from 4390e3ece09SJunchao Zhang // different senders, they might have different lengths (and sparsity patterns) 4400e3ece09SJunchao Zhang PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 4410e3ece09SJunchao Zhang PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 442f0e6e2d1SJunchao Zhang 4430e3ece09SJunchao Zhang PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 4440e3ece09SJunchao Zhang 4450e3ece09SJunchao Zhang for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 4460e3ece09SJunchao Zhang recvRowLen[0] = 0; // since we will make it in CSR format later 4470e3ece09SJunchao Zhang recvRowLen++; // advance the pointer now 4486497c311SBarry Smith for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 4496497c311SBarry Smith for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); 4500e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 4510e3ece09SJunchao Zhang 4520e3ece09SJunchao Zhang // Build the real PetscSF for reducing E rows (buffer to buffer) 4530e3ece09SJunchao Zhang rdisp[0] = 0; 4540e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { 4550e3ece09SJunchao Zhang rdisp[i + 1] = rdisp[i]; 4560e3ece09SJunchao Zhang for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 4570e3ece09SJunchao Zhang } 4580e3ece09SJunchao Zhang recvRowLen--; // put it back into csr format 4590e3ece09SJunchao Zhang for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i]; 4600e3ece09SJunchao Zhang 4616497c311SBarry Smith for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); 4626497c311SBarry Smith for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 4630e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 4640e3ece09SJunchao Zhang 4650e3ece09SJunchao Zhang PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 4660e3ece09SJunchao Zhang PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 4670e3ece09SJunchao Zhang PetscSFNode *iremote; 4680e3ece09SJunchao Zhang 4690e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 4700e3ece09SJunchao Zhang PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 4710e3ece09SJunchao Zhang PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 4720e3ece09SJunchao Zhang 4730e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { 4740e3ece09SJunchao Zhang PetscInt count = 0; 4750e3ece09SJunchao Zhang for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 4760e3ece09SJunchao Zhang for (PetscInt j = 0; j < count; j++) { 4770e3ece09SJunchao Zhang iremote[nleaves + j].rank = ranks[i]; 4780e3ece09SJunchao Zhang iremote[nleaves + j].index = sdisp[i] + j; 4790e3ece09SJunchao Zhang } 4800e3ece09SJunchao Zhang nleaves += count; 4810e3ece09SJunchao Zhang } 4820e3ece09SJunchao Zhang PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 4830e3ece09SJunchao Zhang 4840e3ece09SJunchao Zhang PetscCall(PetscSFCreate(comm, &reduceSF)); 4850e3ece09SJunchao Zhang PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 4860e3ece09SJunchao Zhang 4870e3ece09SJunchao Zhang // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 4880e3ece09SJunchao Zhang PetscInt *sendCol, *recvCol; 4890e3ece09SJunchao Zhang PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 4900e3ece09SJunchao Zhang for (PetscInt k = 0; k < roffset[nranks]; k++) { 4910e3ece09SJunchao Zhang PetscInt i = rmine[k]; // row to be copied 4920e3ece09SJunchao Zhang PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 4930e3ece09SJunchao Zhang PetscInt nzLeft = E_NzLeft[i]; 4940e3ece09SJunchao Zhang PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 4950e3ece09SJunchao Zhang for (PetscInt j = 0; j < alen + blen; j++) { 4960e3ece09SJunchao Zhang if (j < nzLeft) { 4970e3ece09SJunchao Zhang buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 4980e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { 4990e3ece09SJunchao Zhang buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 5000e3ece09SJunchao Zhang } else { 5010e3ece09SJunchao Zhang buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 5020e3ece09SJunchao Zhang } 5030e3ece09SJunchao Zhang } 5040e3ece09SJunchao Zhang } 5050e3ece09SJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 5060e3ece09SJunchao Zhang PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 5070e3ece09SJunchao Zhang 5080e3ece09SJunchao Zhang // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 5090e3ece09SJunchao Zhang PetscInt *recvRowPerm, *recvColSorted; 5100e3ece09SJunchao Zhang PetscInt *recvNzPerm, *recvNzPermSorted; 5110e3ece09SJunchao Zhang PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 5120e3ece09SJunchao Zhang 5130e3ece09SJunchao Zhang for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 5140e3ece09SJunchao Zhang for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 5150e3ece09SJunchao Zhang PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 5160e3ece09SJunchao Zhang 5170e3ece09SJunchao Zhang // i[] array, nz are always easiest to compute 5187b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 5190e3ece09SJunchao Zhang MatRowMapType *Fdi, *Foi; 5200e3ece09SJunchao Zhang PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 5210e3ece09SJunchao Zhang PetscInt iter; 5220e3ece09SJunchao Zhang 5230e3ece09SJunchao Zhang Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 5240e3ece09SJunchao Zhang Kokkos::deep_copy(Foi_h, 0); 5250e3ece09SJunchao Zhang Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 5260e3ece09SJunchao Zhang Foi = Foi_h.data() + 1; 5270e3ece09SJunchao Zhang iter = 0; 5280e3ece09SJunchao Zhang while (iter < recvRowCnt) { // iter over received rows 5290e3ece09SJunchao Zhang PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 5300e3ece09SJunchao Zhang PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 5310e3ece09SJunchao Zhang 5320e3ece09SJunchao Zhang while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 5330e3ece09SJunchao Zhang 5340e3ece09SJunchao Zhang // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted 5350e3ece09SJunchao Zhang PetscInt nz = 0; // nz (with dups) in the current row 5360e3ece09SJunchao Zhang PetscInt *jbuf = recvColSorted + FnzDups; 5370e3ece09SJunchao Zhang PetscInt *pbuf = recvNzPermSorted + FnzDups; 5380e3ece09SJunchao Zhang PetscInt *jbuf2 = jbuf; // temp pointers 5390e3ece09SJunchao Zhang PetscInt *pbuf2 = pbuf; 5400e3ece09SJunchao Zhang for (PetscInt d = 0; d < dupRows; d++) { 5410e3ece09SJunchao Zhang PetscInt i = recvRowPerm[iter + d]; 5420e3ece09SJunchao Zhang PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 5430e3ece09SJunchao Zhang PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 5440e3ece09SJunchao Zhang PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 5450e3ece09SJunchao Zhang jbuf2 += len; 5460e3ece09SJunchao Zhang pbuf2 += len; 5470e3ece09SJunchao Zhang nz += len; 5480e3ece09SJunchao Zhang } 5490e3ece09SJunchao Zhang PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 5500e3ece09SJunchao Zhang 5510e3ece09SJunchao Zhang // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 5520e3ece09SJunchao Zhang PetscInt cur = 0; 5530e3ece09SJunchao Zhang while (cur < nz) { 5540e3ece09SJunchao Zhang PetscInt curColIdx = jbuf[cur]; 5550e3ece09SJunchao Zhang PetscInt dups = 1; 5560e3ece09SJunchao Zhang 5570e3ece09SJunchao Zhang while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 5580e3ece09SJunchao Zhang if (curColIdx >= cstart && curColIdx < cend) { 5590e3ece09SJunchao Zhang Fdi[curRowIdx]++; 5600e3ece09SJunchao Zhang FdnzDups += dups; 5610e3ece09SJunchao Zhang } else { 5620e3ece09SJunchao Zhang Foi[curRowIdx]++; 5630e3ece09SJunchao Zhang FonzDups += dups; 5640e3ece09SJunchao Zhang } 5650e3ece09SJunchao Zhang cur += dups; 5660e3ece09SJunchao Zhang } 5670e3ece09SJunchao Zhang 5680e3ece09SJunchao Zhang FnzDups += nz; 5690e3ece09SJunchao Zhang iter += dupRows; // Move to next unique row 5700e3ece09SJunchao Zhang } 5710e3ece09SJunchao Zhang 5720e3ece09SJunchao Zhang Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 5730e3ece09SJunchao Zhang Foi = Foi_h.data(); 5740e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 5750e3ece09SJunchao Zhang Fdi[i + 1] += Fdi[i]; 5760e3ece09SJunchao Zhang Foi[i + 1] += Foi[i]; 5770e3ece09SJunchao Zhang } 5780e3ece09SJunchao Zhang Fdnz = Fdi[Fm]; 5790e3ece09SJunchao Zhang Fonz = Foi[Fm]; 5800e3ece09SJunchao Zhang PetscCall(PetscFree2(sendCol, recvCol)); 5810e3ece09SJunchao Zhang 5820e3ece09SJunchao Zhang // Allocate j, jmap, jperm for Fd and Fo 5837b8d4ba6SJunchao Zhang MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 5847b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 5857b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 5860e3ece09SJunchao Zhang MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 5870e3ece09SJunchao Zhang MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 5880e3ece09SJunchao Zhang MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 5890e3ece09SJunchao Zhang 5900e3ece09SJunchao Zhang // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 5910e3ece09SJunchao Zhang Fdjmap[0] = 0; 5920e3ece09SJunchao Zhang Fojmap[0] = 0; 5930e3ece09SJunchao Zhang FnzDups = 0; 5940e3ece09SJunchao Zhang Fdnz = 0; 5950e3ece09SJunchao Zhang Fonz = 0; 5960e3ece09SJunchao Zhang iter = 0; // iter over received rows 5970e3ece09SJunchao Zhang while (iter < recvRowCnt) { 5980e3ece09SJunchao Zhang PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 5990e3ece09SJunchao Zhang PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 6000e3ece09SJunchao Zhang PetscInt nz = 0; // nz (with dups) in the current row 6010e3ece09SJunchao Zhang 6020e3ece09SJunchao Zhang while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 6030e3ece09SJunchao Zhang for (PetscInt d = 0; d < dupRows; d++) { 6040e3ece09SJunchao Zhang PetscInt i = recvRowPerm[iter + d]; 6050e3ece09SJunchao Zhang nz += recvRowLen[i + 1] - recvRowLen[i]; 6060e3ece09SJunchao Zhang } 6070e3ece09SJunchao Zhang 6080e3ece09SJunchao Zhang PetscInt *jbuf = recvColSorted + FnzDups; 6090e3ece09SJunchao Zhang // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 6100e3ece09SJunchao Zhang PetscInt cur = 0; 6110e3ece09SJunchao Zhang while (cur < nz) { 6120e3ece09SJunchao Zhang PetscInt curColIdx = jbuf[cur]; 6130e3ece09SJunchao Zhang PetscInt dups = 1; 6140e3ece09SJunchao Zhang 6150e3ece09SJunchao Zhang while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 6160e3ece09SJunchao Zhang if (curColIdx >= cstart && curColIdx < cend) { 6170e3ece09SJunchao Zhang Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 6180e3ece09SJunchao Zhang Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 6190e3ece09SJunchao Zhang for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 6200e3ece09SJunchao Zhang FdnzDups += dups; 6210e3ece09SJunchao Zhang Fdnz++; 6220e3ece09SJunchao Zhang } else { 6230e3ece09SJunchao Zhang Foj[Fonz] = curColIdx; // in global 6240e3ece09SJunchao Zhang Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 6250e3ece09SJunchao Zhang for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 6260e3ece09SJunchao Zhang FonzDups += dups; 6270e3ece09SJunchao Zhang Fonz++; 6280e3ece09SJunchao Zhang } 6290e3ece09SJunchao Zhang cur += dups; 6300e3ece09SJunchao Zhang FnzDups += dups; 6310e3ece09SJunchao Zhang } 6320e3ece09SJunchao Zhang iter += dupRows; // Move to next unique row 6330e3ece09SJunchao Zhang } 6340e3ece09SJunchao Zhang PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 6350e3ece09SJunchao Zhang PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs)); 6360e3ece09SJunchao Zhang 6370e3ece09SJunchao Zhang // Combine global column indices in garray1[] and Foj[] 6380e3ece09SJunchao Zhang PetscInt n2, *garray2; 6390e3ece09SJunchao Zhang 6400e3ece09SJunchao Zhang PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 6410e3ece09SJunchao Zhang mm->sf = reduceSF; 6427b8d4ba6SJunchao Zhang mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 6437b8d4ba6SJunchao Zhang mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 644aaa8cc7dSPierre Jolivet mm->garray = garray2; // give ownership, so no free 6450e3ece09SJunchao Zhang mm->n = n2; 6460e3ece09SJunchao Zhang mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 6470e3ece09SJunchao Zhang mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 6480e3ece09SJunchao Zhang mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 6490e3ece09SJunchao Zhang mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 6500e3ece09SJunchao Zhang mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 6510e3ece09SJunchao Zhang 6520e3ece09SJunchao Zhang // Output Fd and Fo in KokkosCsrMatrix format 6537b8d4ba6SJunchao Zhang MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 6540e3ece09SJunchao Zhang MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 6550e3ece09SJunchao Zhang MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 6567b8d4ba6SJunchao Zhang MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 6570e3ece09SJunchao Zhang MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 6580e3ece09SJunchao Zhang MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 6590e3ece09SJunchao Zhang 6600e3ece09SJunchao Zhang PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 6610e3ece09SJunchao Zhang PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 6620e3ece09SJunchao Zhang 6630e3ece09SJunchao Zhang // Compute kernel launch parameters in merging E 6640e3ece09SJunchao Zhang PetscInt teamSize, vectorLength, rowsPerTeam; 6650e3ece09SJunchao Zhang 6660e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 6670e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 6680e3ece09SJunchao Zhang mm->E_TeamSize = teamSize; 6690e3ece09SJunchao Zhang mm->E_VectorLength = vectorLength; 6700e3ece09SJunchao Zhang mm->E_RowsPerTeam = rowsPerTeam; 6710e3ece09SJunchao Zhang } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 6720e3ece09SJunchao Zhang 6730e3ece09SJunchao Zhang // Handy aliases 6740e3ece09SJunchao Zhang auto &Aa = A.values; 6750e3ece09SJunchao Zhang auto &Ba = B.values; 6760e3ece09SJunchao Zhang const auto &Ai = A.graph.row_map; 6770e3ece09SJunchao Zhang const auto &Bi = B.graph.row_map; 6780e3ece09SJunchao Zhang const auto &E_NzLeft = mm->E_NzLeft; 6790e3ece09SJunchao Zhang auto &leafBuf = mm->leafBuf; 6800e3ece09SJunchao Zhang auto &rootBuf = mm->rootBuf; 6810e3ece09SJunchao Zhang PetscSF reduceSF = mm->sf; 6820e3ece09SJunchao Zhang PetscInt Em = A.numRows(); 6830e3ece09SJunchao Zhang PetscInt teamSize = mm->E_TeamSize; 6840e3ece09SJunchao Zhang PetscInt vectorLength = mm->E_VectorLength; 6850e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->E_RowsPerTeam; 6860e3ece09SJunchao Zhang PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 6870e3ece09SJunchao Zhang 6880e3ece09SJunchao Zhang // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 6890e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 690d326c3f1SJunchao Zhang Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 6910e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 6920e3ece09SJunchao Zhang PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 6930e3ece09SJunchao Zhang if (i < Em) { 6940e3ece09SJunchao Zhang PetscInt disp = Ai(i) + Bi(i); 6950e3ece09SJunchao Zhang PetscInt alen = Ai(i + 1) - Ai(i); 6960e3ece09SJunchao Zhang PetscInt blen = Bi(i + 1) - Bi(i); 6970e3ece09SJunchao Zhang PetscInt nzleft = E_NzLeft(i); 6980e3ece09SJunchao Zhang 6990e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 7000e3ece09SJunchao Zhang MatScalar &val = leafBuf(disp + j); 7010e3ece09SJunchao Zhang if (j < nzleft) { // B left 7020e3ece09SJunchao Zhang val = Ba(Bi(i) + j); 7030e3ece09SJunchao Zhang } else if (j < nzleft + alen) { // diag A 7040e3ece09SJunchao Zhang val = Aa(Ai(i) + j - nzleft); 7050e3ece09SJunchao Zhang } else { // B right 7060e3ece09SJunchao Zhang val = Ba(Bi(i) + j - alen); 707f0e6e2d1SJunchao Zhang } 708f0e6e2d1SJunchao Zhang }); 709f0e6e2d1SJunchao Zhang } 710f0e6e2d1SJunchao Zhang }); 7110e3ece09SJunchao Zhang })); 7120e3ece09SJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 713f0e6e2d1SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 714f0e6e2d1SJunchao Zhang } 7150e3ece09SJunchao Zhang 716aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce. 7170e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 7180e3ece09SJunchao Zhang { 7190e3ece09SJunchao Zhang auto &leafBuf = mm->leafBuf; 7200e3ece09SJunchao Zhang auto &rootBuf = mm->rootBuf; 7210e3ece09SJunchao Zhang auto &Fda = mm->Fd.values; 7220e3ece09SJunchao Zhang const auto &Fdjmap = mm->Fdjmap; 7230e3ece09SJunchao Zhang const auto &Fdjperm = mm->Fdjperm; 7240e3ece09SJunchao Zhang auto Fdnz = mm->Fd.nnz(); 7250e3ece09SJunchao Zhang auto &Foa = mm->Fo.values; 7260e3ece09SJunchao Zhang const auto &Fojmap = mm->Fojmap; 7270e3ece09SJunchao Zhang const auto &Fojperm = mm->Fojperm; 7280e3ece09SJunchao Zhang auto Fonz = mm->Fo.nnz(); 7290e3ece09SJunchao Zhang PetscSF reduceSF = mm->sf; 7300e3ece09SJunchao Zhang 731d326c3f1SJunchao Zhang PetscFunctionBegin; 7320e3ece09SJunchao Zhang PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 7330e3ece09SJunchao Zhang 7340e3ece09SJunchao Zhang // Reduce data in rootBuf to Fd and Fo 7350e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 736d326c3f1SJunchao Zhang Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 7370e3ece09SJunchao Zhang PetscScalar sum = 0.0; 7380e3ece09SJunchao Zhang for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 7390e3ece09SJunchao Zhang Fda(i) = sum; 7400e3ece09SJunchao Zhang })); 7410e3ece09SJunchao Zhang 7420e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 743d326c3f1SJunchao Zhang Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 7440e3ece09SJunchao Zhang PetscScalar sum = 0.0; 7450e3ece09SJunchao Zhang for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 7460e3ece09SJunchao Zhang Foa(i) = sum; 7470e3ece09SJunchao Zhang })); 7480e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 7490e3ece09SJunchao Zhang } 7500e3ece09SJunchao Zhang 7510e3ece09SJunchao Zhang /* 7520e3ece09SJunchao Zhang MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 7530e3ece09SJunchao Zhang 7540e3ece09SJunchao Zhang This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 7550e3ece09SJunchao Zhang device and involves various index mapping. 7560e3ece09SJunchao Zhang 7570e3ece09SJunchao Zhang In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 7580e3ece09SJunchao Zhang Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 7590e3ece09SJunchao Zhang to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 7600e3ece09SJunchao Zhang F has the same column layout as E. 7610e3ece09SJunchao Zhang 7620e3ece09SJunchao Zhang Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 763aaa8cc7dSPierre Jolivet Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 7640e3ece09SJunchao Zhang Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 7650e3ece09SJunchao Zhang column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 7660e3ece09SJunchao Zhang column indices in Fo and update Fo with local indices. 7670e3ece09SJunchao Zhang 7680e3ece09SJunchao Zhang Input Parameters: 7690e3ece09SJunchao Zhang + E - the MPIAIJKOKKOS matrix 7709c89aa79SPierre Jolivet . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 7710e3ece09SJunchao Zhang . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 7720e3ece09SJunchao Zhang - mm - to stash matproduct intermediate data structures 7730e3ece09SJunchao Zhang 7740e3ece09SJunchao Zhang Output Parameters: 7750e3ece09SJunchao Zhang + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 7760e3ece09SJunchao Zhang - mm - contains various info, such as garray2[], Fd, Fo, etc. 7770e3ece09SJunchao Zhang 7780e3ece09SJunchao Zhang Notes: 7790e3ece09SJunchao Zhang When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 7800e3ece09SJunchao Zhang The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 7810e3ece09SJunchao Zhang */ 7820e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 7830e3ece09SJunchao Zhang { 7840e3ece09SJunchao Zhang Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 7850e3ece09SJunchao Zhang Mat A = empi->A, B = empi->B; // diag and off-diag 7860e3ece09SJunchao Zhang Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 7870e3ece09SJunchao Zhang PetscInt Em = E->rmap->n; // #local rows 7880e3ece09SJunchao Zhang MPI_Comm comm; 7890e3ece09SJunchao Zhang 7900e3ece09SJunchao Zhang PetscFunctionBegin; 7910e3ece09SJunchao Zhang PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 7920e3ece09SJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 7930e3ece09SJunchao Zhang Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 7940e3ece09SJunchao Zhang PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 7950e3ece09SJunchao Zhang const PetscInt *garray1 = empi->garray; // its size is n1 7960e3ece09SJunchao Zhang PetscInt cstart, cend; 7970e3ece09SJunchao Zhang PetscSF bcastSF; 7980e3ece09SJunchao Zhang 7990e3ece09SJunchao Zhang PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 8000e3ece09SJunchao Zhang 8010e3ece09SJunchao Zhang // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 8027b8d4ba6SJunchao Zhang PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 8030e3ece09SJunchao Zhang PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 8040e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) { 8050e3ece09SJunchao Zhang const PetscInt *first, *last, *it; 8060e3ece09SJunchao Zhang PetscInt count, step; 8070e3ece09SJunchao Zhang // std::lower_bound(first,last,cstart), but need to use global column indices 8080e3ece09SJunchao Zhang first = Bj + Bi[i]; 8090e3ece09SJunchao Zhang last = Bj + Bi[i + 1]; 8100e3ece09SJunchao Zhang count = last - first; 8110e3ece09SJunchao Zhang while (count > 0) { 8120e3ece09SJunchao Zhang it = first; 8130e3ece09SJunchao Zhang step = count / 2; 8140e3ece09SJunchao Zhang it += step; 8150e3ece09SJunchao Zhang if (empi->garray[*it] < cstart) { // map local to global 8160e3ece09SJunchao Zhang first = ++it; 8170e3ece09SJunchao Zhang count -= step + 1; 8180e3ece09SJunchao Zhang } else count = step; 8190e3ece09SJunchao Zhang } 8200e3ece09SJunchao Zhang E_NzLeft[i] = first - (Bj + Bi[i]); 8210e3ece09SJunchao Zhang E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 8220e3ece09SJunchao Zhang } 8230e3ece09SJunchao Zhang 8240e3ece09SJunchao Zhang // Compute row pointer Fi of F 8250e3ece09SJunchao Zhang PetscInt *Fi, Fm, Fnz; 8260e3ece09SJunchao Zhang PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 8270e3ece09SJunchao Zhang PetscCall(PetscMalloc1(Fm + 1, &Fi)); 8280e3ece09SJunchao Zhang Fi[0] = 0; 8290e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 8300e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 8310e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 8320e3ece09SJunchao Zhang Fnz = Fi[Fm]; 8330e3ece09SJunchao Zhang 8340e3ece09SJunchao Zhang // Build the real PetscSF for bcasting E rows (buffer to buffer) 8350e3ece09SJunchao Zhang const PetscMPIInt *iranks, *ranks; 8360e3ece09SJunchao Zhang const PetscInt *ioffset, *irootloc, *roffset; 837c09cee04SJames Wright PetscMPIInt niranks, nranks; 838c09cee04SJames Wright PetscInt *sdisp, *rdisp; 8390e3ece09SJunchao Zhang MPI_Request *reqs; 8400e3ece09SJunchao Zhang PetscMPIInt tag; 8410e3ece09SJunchao Zhang 8420e3ece09SJunchao Zhang PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 8430e3ece09SJunchao Zhang PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 8440e3ece09SJunchao Zhang PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 8450e3ece09SJunchao Zhang 8460e3ece09SJunchao Zhang sdisp[0] = 0; // send displacement 8470e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { 8480e3ece09SJunchao Zhang sdisp[i + 1] = sdisp[i]; 8490e3ece09SJunchao Zhang for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 8500e3ece09SJunchao Zhang PetscInt r = irootloc[j]; // row to be sent 8510e3ece09SJunchao Zhang sdisp[i + 1] += E_RowLen[r]; 8520e3ece09SJunchao Zhang } 8530e3ece09SJunchao Zhang } 8540e3ece09SJunchao Zhang 8550e3ece09SJunchao Zhang PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 8566497c311SBarry Smith for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 8576497c311SBarry Smith for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 8580e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 8590e3ece09SJunchao Zhang 8600e3ece09SJunchao Zhang PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 8610e3ece09SJunchao Zhang PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 8620e3ece09SJunchao Zhang PetscSFNode *iremote; // give ownership to bcastSF 8630e3ece09SJunchao Zhang PetscCall(PetscMalloc1(nleaves, &iremote)); 8640e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 8650e3ece09SJunchao Zhang PetscInt k = 0; 8660e3ece09SJunchao Zhang for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i] 8670e3ece09SJunchao Zhang iremote[j].rank = ranks[i]; 8680e3ece09SJunchao Zhang iremote[j].index = rdisp[i] + k; // their root location 8690e3ece09SJunchao Zhang k++; 8700e3ece09SJunchao Zhang } 8710e3ece09SJunchao Zhang } 8720e3ece09SJunchao Zhang PetscCall(PetscSFCreate(comm, &bcastSF)); 8730e3ece09SJunchao Zhang PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 8740e3ece09SJunchao Zhang PetscCall(PetscFree3(sdisp, rdisp, reqs)); 8750e3ece09SJunchao Zhang 8760e3ece09SJunchao Zhang // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 8777b8d4ba6SJunchao Zhang PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 8780e3ece09SJunchao Zhang PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 8790e3ece09SJunchao Zhang rowoffset[0] = 0; 8807b8d4ba6SJunchao Zhang for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 8810e3ece09SJunchao Zhang 8820e3ece09SJunchao Zhang // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 8830e3ece09SJunchao Zhang PetscInt *jbuf, *Fj; 8840e3ece09SJunchao Zhang PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 8850e3ece09SJunchao Zhang for (PetscInt k = 0; k < ioffset[niranks]; k++) { 8860e3ece09SJunchao Zhang PetscInt i = irootloc[k]; // row to be copied 8870e3ece09SJunchao Zhang PetscInt *buf = &jbuf[rowoffset[k]]; 8880e3ece09SJunchao Zhang PetscInt nzLeft = E_NzLeft[i]; 8890e3ece09SJunchao Zhang PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 8900e3ece09SJunchao Zhang for (PetscInt j = 0; j < alen + blen; j++) { 8910e3ece09SJunchao Zhang if (j < nzLeft) { 8920e3ece09SJunchao Zhang buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 8930e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { 8940e3ece09SJunchao Zhang buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 8950e3ece09SJunchao Zhang } else { 8960e3ece09SJunchao Zhang buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 8970e3ece09SJunchao Zhang } 8980e3ece09SJunchao Zhang } 8990e3ece09SJunchao Zhang } 9000e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 9010e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 9020e3ece09SJunchao Zhang 9030e3ece09SJunchao Zhang // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 9047b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 9057b8d4ba6SJunchao Zhang MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 9060e3ece09SJunchao Zhang MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 9070e3ece09SJunchao Zhang MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 9080e3ece09SJunchao Zhang 9090e3ece09SJunchao Zhang Fdi[0] = Foi[0] = 0; 9100e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 9110e3ece09SJunchao Zhang PetscInt *first, *last, *lb1, *lb2; 9120e3ece09SJunchao Zhang // cut the row into: Left, [cstart, cend), Right 9130e3ece09SJunchao Zhang first = Fj + Fi[i]; 9140e3ece09SJunchao Zhang last = Fj + Fi[i + 1]; 9150e3ece09SJunchao Zhang lb1 = std::lower_bound(first, last, cstart); 9160e3ece09SJunchao Zhang F_NzLeft[i] = lb1 - first; 9170e3ece09SJunchao Zhang lb2 = std::lower_bound(first, last, cend); 9180e3ece09SJunchao Zhang Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 9190e3ece09SJunchao Zhang Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 9200e3ece09SJunchao Zhang } 9210e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 9220e3ece09SJunchao Zhang Fdi[i + 1] += Fdi[i]; 9230e3ece09SJunchao Zhang Foi[i + 1] += Foi[i]; 9240e3ece09SJunchao Zhang } 9250e3ece09SJunchao Zhang 9260e3ece09SJunchao Zhang // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 9270e3ece09SJunchao Zhang PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 9287b8d4ba6SJunchao Zhang MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 9290e3ece09SJunchao Zhang MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 9300e3ece09SJunchao Zhang 9310e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 9320e3ece09SJunchao Zhang PetscInt nzLeft = F_NzLeft[i]; 9330e3ece09SJunchao Zhang PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 9340e3ece09SJunchao Zhang for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 9350e3ece09SJunchao Zhang gid = Fj[Fi[i] + j]; 9360e3ece09SJunchao Zhang if (j < nzLeft) { // left, in global 9370e3ece09SJunchao Zhang Foj[Foi[i] + j] = gid; 9380e3ece09SJunchao Zhang } else if (j < nzLeft + len) { // diag, in local 9390e3ece09SJunchao Zhang Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 9400e3ece09SJunchao Zhang } else { // right, in global 9410e3ece09SJunchao Zhang Foj[Foi[i] + j - len] = gid; 9420e3ece09SJunchao Zhang } 9430e3ece09SJunchao Zhang } 9440e3ece09SJunchao Zhang } 9450e3ece09SJunchao Zhang PetscCall(PetscFree2(jbuf, Fj)); 9460e3ece09SJunchao Zhang PetscCall(PetscFree(Fi)); 9470e3ece09SJunchao Zhang 9480e3ece09SJunchao Zhang // Reduce global indices in Foj[] and garray1[] into local ones 9490e3ece09SJunchao Zhang PetscInt n2, *garray2; 9500e3ece09SJunchao Zhang PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 9510e3ece09SJunchao Zhang 9520e3ece09SJunchao Zhang // Record the plans built above, for reuse 9530e3ece09SJunchao Zhang PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety 9547b8d4ba6SJunchao Zhang PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 9550e3ece09SJunchao Zhang Kokkos::deep_copy(irootloc_h, tmp); 9560e3ece09SJunchao Zhang mm->sf = bcastSF; 9570e3ece09SJunchao Zhang mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 9580e3ece09SJunchao Zhang mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 9590e3ece09SJunchao Zhang mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 9600e3ece09SJunchao Zhang mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 9617b8d4ba6SJunchao Zhang mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 9627b8d4ba6SJunchao Zhang mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 9630e3ece09SJunchao Zhang mm->garray = garray2; 9640e3ece09SJunchao Zhang mm->n = n2; 9650e3ece09SJunchao Zhang 9660e3ece09SJunchao Zhang // Output Fd and Fo in KokkosCsrMatrix format 9677b8d4ba6SJunchao Zhang MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 9680e3ece09SJunchao Zhang MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 9690e3ece09SJunchao Zhang MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 9700e3ece09SJunchao Zhang MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 9710e3ece09SJunchao Zhang MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 9720e3ece09SJunchao Zhang 9730e3ece09SJunchao Zhang PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 9740e3ece09SJunchao Zhang PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 9750e3ece09SJunchao Zhang 9760e3ece09SJunchao Zhang // Compute kernel launch parameters in merging E or splitting F 9770e3ece09SJunchao Zhang PetscInt teamSize, vectorLength, rowsPerTeam; 9780e3ece09SJunchao Zhang 9790e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 9800e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 9810e3ece09SJunchao Zhang mm->E_TeamSize = teamSize; 9820e3ece09SJunchao Zhang mm->E_VectorLength = vectorLength; 9830e3ece09SJunchao Zhang mm->E_RowsPerTeam = rowsPerTeam; 9840e3ece09SJunchao Zhang 9850e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 9860e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 9870e3ece09SJunchao Zhang mm->F_TeamSize = teamSize; 9880e3ece09SJunchao Zhang mm->F_VectorLength = vectorLength; 9890e3ece09SJunchao Zhang mm->F_RowsPerTeam = rowsPerTeam; 9900e3ece09SJunchao Zhang } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 9910e3ece09SJunchao Zhang 9920e3ece09SJunchao Zhang // Sync E's value to device 9930e3ece09SJunchao Zhang akok->a_dual.sync_device(); 9940e3ece09SJunchao Zhang bkok->a_dual.sync_device(); 9950e3ece09SJunchao Zhang 9960e3ece09SJunchao Zhang // Handy aliases 9970e3ece09SJunchao Zhang const auto &Aa = akok->a_dual.view_device(); 9980e3ece09SJunchao Zhang const auto &Ba = bkok->a_dual.view_device(); 9990e3ece09SJunchao Zhang const auto &Ai = akok->i_dual.view_device(); 10000e3ece09SJunchao Zhang const auto &Bi = bkok->i_dual.view_device(); 10010e3ece09SJunchao Zhang 10020e3ece09SJunchao Zhang // Fetch the plans 10030e3ece09SJunchao Zhang PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 10040e3ece09SJunchao Zhang PetscSF &bcastSF = mm->sf; 10050e3ece09SJunchao Zhang MatScalarKokkosView &rootBuf = mm->rootBuf; 10060e3ece09SJunchao Zhang MatScalarKokkosView &leafBuf = mm->leafBuf; 10070e3ece09SJunchao Zhang PetscIntKokkosView &irootloc = mm->irootloc; 10080e3ece09SJunchao Zhang PetscIntKokkosView &rowoffset = mm->rowoffset; 10090e3ece09SJunchao Zhang 10100e3ece09SJunchao Zhang PetscInt teamSize = mm->E_TeamSize; 10110e3ece09SJunchao Zhang PetscInt vectorLength = mm->E_VectorLength; 10120e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->E_RowsPerTeam; 10130e3ece09SJunchao Zhang PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 10140e3ece09SJunchao Zhang 10150e3ece09SJunchao Zhang // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 10160e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 1017d326c3f1SJunchao Zhang Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 10180e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 10190e3ece09SJunchao Zhang size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 10200e3ece09SJunchao Zhang if (r < irootloc.extent(0)) { 10210e3ece09SJunchao Zhang PetscInt i = irootloc(r); // row i of E 10220e3ece09SJunchao Zhang PetscInt disp = rowoffset(r); 10230e3ece09SJunchao Zhang PetscInt alen = Ai(i + 1) - Ai(i); 10240e3ece09SJunchao Zhang PetscInt blen = Bi(i + 1) - Bi(i); 10250e3ece09SJunchao Zhang PetscInt nzleft = E_NzLeft(i); 10260e3ece09SJunchao Zhang 10270e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 10280e3ece09SJunchao Zhang if (j < nzleft) { // B left 10290e3ece09SJunchao Zhang rootBuf(disp + j) = Ba(Bi(i) + j); 10300e3ece09SJunchao Zhang } else if (j < nzleft + alen) { // diag A 10310e3ece09SJunchao Zhang rootBuf(disp + j) = Aa(Ai(i) + j - nzleft); 10320e3ece09SJunchao Zhang } else { // B right 10330e3ece09SJunchao Zhang rootBuf(disp + j) = Ba(Bi(i) + j - alen); 10340e3ece09SJunchao Zhang } 10350e3ece09SJunchao Zhang }); 10360e3ece09SJunchao Zhang } 10370e3ece09SJunchao Zhang }); 10380e3ece09SJunchao Zhang })); 10390e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 10400e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 10410e3ece09SJunchao Zhang } 10420e3ece09SJunchao Zhang 10430e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast. 10440e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 10450e3ece09SJunchao Zhang { 10460e3ece09SJunchao Zhang PetscFunctionBegin; 10470e3ece09SJunchao Zhang const auto &Fd = mm->Fd; 10480e3ece09SJunchao Zhang const auto &Fo = mm->Fo; 10490e3ece09SJunchao Zhang const auto &Fdi = Fd.graph.row_map; 10500e3ece09SJunchao Zhang const auto &Foi = Fo.graph.row_map; 10510e3ece09SJunchao Zhang auto &Fda = Fd.values; 10520e3ece09SJunchao Zhang auto &Foa = Fo.values; 10530e3ece09SJunchao Zhang auto Fm = Fd.numRows(); 10540e3ece09SJunchao Zhang 10550e3ece09SJunchao Zhang PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 10560e3ece09SJunchao Zhang PetscSF &bcastSF = mm->sf; 10570e3ece09SJunchao Zhang MatScalarKokkosView &rootBuf = mm->rootBuf; 10580e3ece09SJunchao Zhang MatScalarKokkosView &leafBuf = mm->leafBuf; 10590e3ece09SJunchao Zhang PetscInt teamSize = mm->F_TeamSize; 10600e3ece09SJunchao Zhang PetscInt vectorLength = mm->F_VectorLength; 10610e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->F_RowsPerTeam; 10620e3ece09SJunchao Zhang PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 10630e3ece09SJunchao Zhang 10640e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 10650e3ece09SJunchao Zhang 10660e3ece09SJunchao Zhang // Update Fda and Foa with new data in leafBuf (as if it is Fa) 10670e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 1068d326c3f1SJunchao Zhang Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 10690e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 10700e3ece09SJunchao Zhang PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 10710e3ece09SJunchao Zhang if (i < Fm) { 10720e3ece09SJunchao Zhang PetscInt nzLeft = F_NzLeft(i); 10730e3ece09SJunchao Zhang PetscInt alen = Fdi(i + 1) - Fdi(i); 10740e3ece09SJunchao Zhang PetscInt blen = Foi(i + 1) - Foi(i); 10750e3ece09SJunchao Zhang PetscInt Fii = Fdi(i) + Foi(i); 10760e3ece09SJunchao Zhang 10770e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 10780e3ece09SJunchao Zhang PetscScalar val = leafBuf(Fii + j); 10790e3ece09SJunchao Zhang if (j < nzLeft) { // left 10800e3ece09SJunchao Zhang Foa(Foi(i) + j) = val; 10810e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { // diag 10820e3ece09SJunchao Zhang Fda(Fdi(i) + j - nzLeft) = val; 10830e3ece09SJunchao Zhang } else { // right 10840e3ece09SJunchao Zhang Foa(Foi(i) + j - alen) = val; 10850e3ece09SJunchao Zhang } 10860e3ece09SJunchao Zhang }); 10870e3ece09SJunchao Zhang } 10880e3ece09SJunchao Zhang }); 10890e3ece09SJunchao Zhang })); 10900e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 10910e3ece09SJunchao Zhang } 10920e3ece09SJunchao Zhang 10930e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 10940e3ece09SJunchao Zhang { 10950e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 10960e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 10970e3ece09SJunchao Zhang KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 10980e3ece09SJunchao Zhang PetscInt cstart, cend; 10990e3ece09SJunchao Zhang MPI_Comm comm; 11000e3ece09SJunchao Zhang 11010e3ece09SJunchao Zhang PetscFunctionBegin; 11020e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 11030e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 11040e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 11050e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 11060e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 11070e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 11080e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 11090e3ece09SJunchao Zhang 11100e3ece09SJunchao Zhang // TODO: add command line options to select spgemm algorithms 11110e3ece09SJunchao Zhang auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 11120e3ece09SJunchao Zhang 11130e3ece09SJunchao Zhang // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 11140e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 11150e3ece09SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 11160e3ece09SJunchao Zhang spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1117f0e6e2d1SJunchao Zhang #endif 11180e3ece09SJunchao Zhang #endif 11190e3ece09SJunchao Zhang 11200e3ece09SJunchao Zhang PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 11210e3ece09SJunchao Zhang PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 11220e3ece09SJunchao Zhang PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 11230e3ece09SJunchao Zhang PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 11240e3ece09SJunchao Zhang 11250e3ece09SJunchao Zhang // Aot * (B's diag + B's off-diag) 11260e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 11270e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 11280e3ece09SJunchao Zhang // KK spgemm_symbolic() only populates the result's row map, but not its columns. 11290e3ece09SJunchao Zhang // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 11300e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 11310e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 11320e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1133d326c3f1SJunchao Zhang 11340e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C3)); 11350e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C4)); 11360e3ece09SJunchao Zhang #endif 11370e3ece09SJunchao Zhang 11380e3ece09SJunchao Zhang // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 11397b8d4ba6SJunchao Zhang PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 11400e3ece09SJunchao Zhang PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 11410e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 11420e3ece09SJunchao Zhang 11430e3ece09SJunchao Zhang // Adt * (B's diag + B's off-diag) 11440e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 11450e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11460e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 11470e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11480e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 11490e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C1)); 11500e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 11510e3ece09SJunchao Zhang #endif 11520e3ece09SJunchao Zhang 11530e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 11540e3ece09SJunchao Zhang 11550e3ece09SJunchao Zhang // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 11567b8d4ba6SJunchao Zhang MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 11570e3ece09SJunchao Zhang PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1158d326c3f1SJunchao Zhang PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 11590e3ece09SJunchao Zhang PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 11600e3ece09SJunchao Zhang 11610e3ece09SJunchao Zhang // C = (C1+Fd, C2+Fo) 11620e3ece09SJunchao Zhang PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 11630e3ece09SJunchao Zhang PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 11640e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 11650e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 11660e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 11670e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 11680e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 11690e3ece09SJunchao Zhang } 11700e3ece09SJunchao Zhang 11710e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 11720e3ece09SJunchao Zhang { 11730e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 11740e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 11750e3ece09SJunchao Zhang KokkosCsrMatrix Adt, Aot, Bd, Bo; 11760e3ece09SJunchao Zhang MPI_Comm comm; 11770e3ece09SJunchao Zhang 11780e3ece09SJunchao Zhang PetscFunctionBegin; 11790e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 11800e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 11810e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 11820e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 11830e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 11840e3ece09SJunchao Zhang 11850e3ece09SJunchao Zhang // Aot * (B's diag + B's off-diag) 11860e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 11870e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 11880e3ece09SJunchao Zhang 11890e3ece09SJunchao Zhang // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 11900e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 11910e3ece09SJunchao Zhang 11920e3ece09SJunchao Zhang // Adt * (B's diag + B's off-diag) 11930e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 11940e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11950e3ece09SJunchao Zhang 11960e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 11970e3ece09SJunchao Zhang 11980e3ece09SJunchao Zhang // C = (C1+Fd, C2+Fo) 11990e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 12000e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 12010e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 12020e3ece09SJunchao Zhang } 1203f0e6e2d1SJunchao Zhang 1204076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1205076ba34aSJunchao Zhang 1206076ba34aSJunchao Zhang Input Parameters: 1207076ba34aSJunchao Zhang + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1208076ba34aSJunchao Zhang . A - an MPIAIJKOKKOS matrix 1209076ba34aSJunchao Zhang . B - an MPIAIJKOKKOS matrix 1210076ba34aSJunchao Zhang - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 1211076ba34aSJunchao Zhang */ 1212d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1213d71ae5a4SJacob Faibussowitsch { 12140e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 12150e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 12160e3ece09SJunchao Zhang KokkosCsrMatrix Ad, Ao, Bd, Bo; 1217076ba34aSJunchao Zhang 1218076ba34aSJunchao Zhang PetscFunctionBegin; 12190e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 12200e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 12210e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 12220e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 12230e3ece09SJunchao Zhang 12240e3ece09SJunchao Zhang // TODO: add command line options to select spgemm algorithms 12250e3ece09SJunchao Zhang auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 12260e3ece09SJunchao Zhang 12270e3ece09SJunchao Zhang // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 12280e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 12290e3ece09SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 12300e3ece09SJunchao Zhang spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 12310e3ece09SJunchao Zhang #endif 1232f0e6e2d1SJunchao Zhang #endif 1233f0e6e2d1SJunchao Zhang 12340e3ece09SJunchao Zhang mm->kh1.create_spgemm_handle(spgemm_alg); 12350e3ece09SJunchao Zhang mm->kh2.create_spgemm_handle(spgemm_alg); 12360e3ece09SJunchao Zhang mm->kh3.create_spgemm_handle(spgemm_alg); 12370e3ece09SJunchao Zhang mm->kh4.create_spgemm_handle(spgemm_alg); 1238076ba34aSJunchao Zhang 12390e3ece09SJunchao Zhang // Bcast B's rows to form F, and overlap the communication 12407b8d4ba6SJunchao Zhang PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 12410e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1242076ba34aSJunchao Zhang 12430e3ece09SJunchao Zhang // A's diag * (B's diag + B's off-diag) 12440e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 12450e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 12460e3ece09SJunchao Zhang // KK spgemm_symbolic() only populates the result's row map, but not its columns. 12470e3ece09SJunchao Zhang // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 12480e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 12490e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 12500e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 12510e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C1)); 12520e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 12530e3ece09SJunchao Zhang #endif 1254076ba34aSJunchao Zhang 12550e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1256076ba34aSJunchao Zhang 12570e3ece09SJunchao Zhang // A's off-diag * (F's diag + F's off-diag) 12580e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12590e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12600e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12610e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12620e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 12630e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C3)); 12640e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C4)); 12650e3ece09SJunchao Zhang #endif 1266076ba34aSJunchao Zhang 12670e3ece09SJunchao Zhang // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 12687b8d4ba6SJunchao Zhang MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 12690e3ece09SJunchao Zhang PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1270d326c3f1SJunchao Zhang PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 12710e3ece09SJunchao Zhang mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 12720e3ece09SJunchao Zhang 12730e3ece09SJunchao Zhang // C = (Cd, Co) = (C1+C3, C2+C4) 12740e3ece09SJunchao Zhang mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 12750e3ece09SJunchao Zhang mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 12760e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 12770e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 12780e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 12790e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 12803ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1281076ba34aSJunchao Zhang } 1282076ba34aSJunchao Zhang 12830e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1284d71ae5a4SJacob Faibussowitsch { 12850e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 12860e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 12870e3ece09SJunchao Zhang KokkosCsrMatrix Ad, Ao, Bd, Bo; 1288076ba34aSJunchao Zhang 1289076ba34aSJunchao Zhang PetscFunctionBegin; 12900e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 12910e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 12920e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 12930e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1294076ba34aSJunchao Zhang 12950e3ece09SJunchao Zhang // Bcast B's rows to form F, and overlap the communication 12960e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1297076ba34aSJunchao Zhang 12980e3ece09SJunchao Zhang // A's diag * (B's diag + B's off-diag) 12990e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 13000e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1301076ba34aSJunchao Zhang 13020e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1303076ba34aSJunchao Zhang 13040e3ece09SJunchao Zhang // A's off-diag * (F's diag + F's off-diag) 13050e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 13060e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 13070e3ece09SJunchao Zhang 13080e3ece09SJunchao Zhang // C = (Cd, Co) = (C1+C3, C2+C4) 13090e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 13100e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 13113ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1312076ba34aSJunchao Zhang } 1313076ba34aSJunchao Zhang 131466976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1315d71ae5a4SJacob Faibussowitsch { 13160e3ece09SJunchao Zhang Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 13170e3ece09SJunchao Zhang Mat_Product *product; 13180e3ece09SJunchao Zhang MatProductData_MPIAIJKokkos *pdata; 1319076ba34aSJunchao Zhang MatProductType ptype; 13200e3ece09SJunchao Zhang Mat A, B; 1321076ba34aSJunchao Zhang 1322076ba34aSJunchao Zhang PetscFunctionBegin; 13230e3ece09SJunchao Zhang MatCheckProduct(C, 1); // make sure C is a product 13240e3ece09SJunchao Zhang product = C->product; 13250e3ece09SJunchao Zhang pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1326076ba34aSJunchao Zhang ptype = product->type; 1327076ba34aSJunchao Zhang A = product->A; 1328076ba34aSJunchao Zhang B = product->B; 1329076ba34aSJunchao Zhang 13300e3ece09SJunchao Zhang // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 13310e3ece09SJunchao Zhang // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 13320e3ece09SJunchao Zhang // we still do numeric. 13330e3ece09SJunchao Zhang if (pdata->reusesym) { // numeric reuses results from symbolic 13340e3ece09SJunchao Zhang pdata->reusesym = PETSC_FALSE; 13353ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1336076ba34aSJunchao Zhang } 1337076ba34aSJunchao Zhang 1338076ba34aSJunchao Zhang if (ptype == MATPRODUCT_AB) { 13390e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1340076ba34aSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 13410e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 13420e3ece09SJunchao Zhang } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 13430e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 13440e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1345076ba34aSJunchao Zhang } 13460e3ece09SJunchao Zhang 13470e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 13480e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 13493ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1350076ba34aSJunchao Zhang } 1351076ba34aSJunchao Zhang 135266976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1353d71ae5a4SJacob Faibussowitsch { 1354076ba34aSJunchao Zhang Mat A, B; 13550e3ece09SJunchao Zhang Mat_Product *product; 1356076ba34aSJunchao Zhang MatProductType ptype; 13570e3ece09SJunchao Zhang MatProductData_MPIAIJKokkos *pdata; 1358076ba34aSJunchao Zhang MatMatStruct *mm = NULL; 13590e3ece09SJunchao Zhang PetscInt m, n, M, N; 13600e3ece09SJunchao Zhang Mat Cd, Co; 13610e3ece09SJunchao Zhang MPI_Comm comm; 1362076ba34aSJunchao Zhang 1363076ba34aSJunchao Zhang PetscFunctionBegin; 13640e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1365076ba34aSJunchao Zhang MatCheckProduct(C, 1); 13660e3ece09SJunchao Zhang product = C->product; 13670e3ece09SJunchao Zhang PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1368076ba34aSJunchao Zhang ptype = product->type; 1369076ba34aSJunchao Zhang A = product->A; 1370076ba34aSJunchao Zhang B = product->B; 1371076ba34aSJunchao Zhang 1372076ba34aSJunchao Zhang switch (ptype) { 13739371c9d4SSatish Balay case MATPRODUCT_AB: 13749371c9d4SSatish Balay m = A->rmap->n; 13759371c9d4SSatish Balay n = B->cmap->n; 13769371c9d4SSatish Balay M = A->rmap->N; 13779371c9d4SSatish Balay N = B->cmap->N; 13789371c9d4SSatish Balay break; 13799371c9d4SSatish Balay case MATPRODUCT_AtB: 13809371c9d4SSatish Balay m = A->cmap->n; 13819371c9d4SSatish Balay n = B->cmap->n; 13829371c9d4SSatish Balay M = A->cmap->N; 13839371c9d4SSatish Balay N = B->cmap->N; 13849371c9d4SSatish Balay break; 13859371c9d4SSatish Balay case MATPRODUCT_PtAP: 13869371c9d4SSatish Balay m = B->cmap->n; 13879371c9d4SSatish Balay n = B->cmap->n; 13889371c9d4SSatish Balay M = B->cmap->N; 13899371c9d4SSatish Balay N = B->cmap->N; 13909371c9d4SSatish Balay break; /* BtAB */ 1391d71ae5a4SJacob Faibussowitsch default: 13920e3ece09SJunchao Zhang SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1393076ba34aSJunchao Zhang } 1394076ba34aSJunchao Zhang 13959566063dSJacob Faibussowitsch PetscCall(MatSetSizes(C, m, n, M, N)); 13969566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(C->rmap)); 13979566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(C->cmap)); 13989566063dSJacob Faibussowitsch PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1399076ba34aSJunchao Zhang 14000e3ece09SJunchao Zhang pdata = new MatProductData_MPIAIJKokkos(); 14010e3ece09SJunchao Zhang pdata->reusesym = product->api_user; 1402076ba34aSJunchao Zhang 1403076ba34aSJunchao Zhang if (ptype == MATPRODUCT_AB) { 14040e3ece09SJunchao Zhang auto mmAB = new MatMatStruct_AB(); 14050e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 14060e3ece09SJunchao Zhang mm = pdata->mmAB = mmAB; 1407076ba34aSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 14080e3ece09SJunchao Zhang auto mmAtB = new MatMatStruct_AtB(); 14090e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 14100e3ece09SJunchao Zhang mm = pdata->mmAtB = mmAtB; 14110e3ece09SJunchao Zhang } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 14120e3ece09SJunchao Zhang Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 14130e3ece09SJunchao Zhang 14140e3ece09SJunchao Zhang auto mmAB = new MatMatStruct_AB(); 14150e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 14160e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 14170e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 14180e3ece09SJunchao Zhang pdata->mmAB = mmAB; 14190e3ece09SJunchao Zhang 14200e3ece09SJunchao Zhang m = A->rmap->n; // Z's layout 14210e3ece09SJunchao Zhang n = B->cmap->n; 14220e3ece09SJunchao Zhang M = A->rmap->N; 14230e3ece09SJunchao Zhang N = B->cmap->N; 14240e3ece09SJunchao Zhang PetscCall(MatCreate(comm, &Z)); 14250e3ece09SJunchao Zhang PetscCall(MatSetSizes(Z, m, n, M, N)); 14260e3ece09SJunchao Zhang PetscCall(PetscLayoutSetUp(Z->rmap)); 14270e3ece09SJunchao Zhang PetscCall(PetscLayoutSetUp(Z->cmap)); 14280e3ece09SJunchao Zhang PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 14290e3ece09SJunchao Zhang PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 14300e3ece09SJunchao Zhang 14310e3ece09SJunchao Zhang auto mmAtB = new MatMatStruct_AtB(); 14320e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 14330e3ece09SJunchao Zhang 14340e3ece09SJunchao Zhang pdata->Z = Z; // give ownership to pdata 14350e3ece09SJunchao Zhang mm = pdata->mmAtB = mmAtB; 1436076ba34aSJunchao Zhang } 14370e3ece09SJunchao Zhang 14380e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 14390e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 14400e3ece09SJunchao Zhang PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 14410e3ece09SJunchao Zhang 14420e3ece09SJunchao Zhang C->product->data = pdata; 1443076ba34aSJunchao Zhang C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1444076ba34aSJunchao Zhang C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 14453ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1446076ba34aSJunchao Zhang } 1447076ba34aSJunchao Zhang 1448d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1449d71ae5a4SJacob Faibussowitsch { 1450076ba34aSJunchao Zhang Mat_Product *product = mat->product; 1451076ba34aSJunchao Zhang PetscBool match = PETSC_FALSE; 1452076ba34aSJunchao Zhang PetscBool usecpu = PETSC_FALSE; 1453076ba34aSJunchao Zhang 1454076ba34aSJunchao Zhang PetscFunctionBegin; 1455076ba34aSJunchao Zhang MatCheckProduct(mat, 1); 145648a46eb9SPierre Jolivet if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1457076ba34aSJunchao Zhang if (match) { /* we can always fallback to the CPU if requested */ 1458076ba34aSJunchao Zhang switch (product->type) { 1459076ba34aSJunchao Zhang case MATPRODUCT_AB: 1460076ba34aSJunchao Zhang if (product->api_user) { 1461d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 14629566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1463d0609cedSBarry Smith PetscOptionsEnd(); 1464076ba34aSJunchao Zhang } else { 1465d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 14669566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1467d0609cedSBarry Smith PetscOptionsEnd(); 1468076ba34aSJunchao Zhang } 1469076ba34aSJunchao Zhang break; 1470076ba34aSJunchao Zhang case MATPRODUCT_AtB: 1471076ba34aSJunchao Zhang if (product->api_user) { 1472d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 14739566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1474d0609cedSBarry Smith PetscOptionsEnd(); 1475076ba34aSJunchao Zhang } else { 1476d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 14779566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1478d0609cedSBarry Smith PetscOptionsEnd(); 1479076ba34aSJunchao Zhang } 1480076ba34aSJunchao Zhang break; 1481076ba34aSJunchao Zhang case MATPRODUCT_PtAP: 1482076ba34aSJunchao Zhang if (product->api_user) { 1483d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 14849566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1485d0609cedSBarry Smith PetscOptionsEnd(); 1486076ba34aSJunchao Zhang } else { 1487d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 14889566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1489d0609cedSBarry Smith PetscOptionsEnd(); 1490076ba34aSJunchao Zhang } 1491076ba34aSJunchao Zhang break; 1492d71ae5a4SJacob Faibussowitsch default: 1493d71ae5a4SJacob Faibussowitsch break; 1494076ba34aSJunchao Zhang } 1495076ba34aSJunchao Zhang match = (PetscBool)!usecpu; 1496076ba34aSJunchao Zhang } 1497076ba34aSJunchao Zhang if (match) { 1498076ba34aSJunchao Zhang switch (product->type) { 1499076ba34aSJunchao Zhang case MATPRODUCT_AB: 1500076ba34aSJunchao Zhang case MATPRODUCT_AtB: 1501d71ae5a4SJacob Faibussowitsch case MATPRODUCT_PtAP: 1502d71ae5a4SJacob Faibussowitsch mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1503d71ae5a4SJacob Faibussowitsch break; 1504d71ae5a4SJacob Faibussowitsch default: 1505d71ae5a4SJacob Faibussowitsch break; 1506076ba34aSJunchao Zhang } 1507076ba34aSJunchao Zhang } 1508076ba34aSJunchao Zhang /* fallback to MPIAIJ ops */ 150948a46eb9SPierre Jolivet if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 15103ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1511076ba34aSJunchao Zhang } 1512076ba34aSJunchao Zhang 15132c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device 15142c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos { 15152c4ab24aSJunchao Zhang PetscCount n; 15162c4ab24aSJunchao Zhang PetscSF sf; 15172c4ab24aSJunchao Zhang PetscCount Annz, Bnnz; 15182c4ab24aSJunchao Zhang PetscCount Annz2, Bnnz2; 15192c4ab24aSJunchao Zhang PetscCountKokkosView Ajmap1, Aperm1; 15202c4ab24aSJunchao Zhang PetscCountKokkosView Bjmap1, Bperm1; 15212c4ab24aSJunchao Zhang PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 15222c4ab24aSJunchao Zhang PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 15232c4ab24aSJunchao Zhang PetscCountKokkosView Cperm1; 15242c4ab24aSJunchao Zhang MatScalarKokkosView sendbuf, recvbuf; 15252c4ab24aSJunchao Zhang 152692896123SJunchao Zhang MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) 15272c4ab24aSJunchao Zhang { 152892896123SJunchao Zhang auto &exec = PetscGetKokkosExecutionSpace(); 152992896123SJunchao Zhang 153092896123SJunchao Zhang n = coo_h->n; 153192896123SJunchao Zhang sf = coo_h->sf; 153292896123SJunchao Zhang Annz = coo_h->Annz; 153392896123SJunchao Zhang Bnnz = coo_h->Bnnz; 153492896123SJunchao Zhang Annz2 = coo_h->Annz2; 153592896123SJunchao Zhang Bnnz2 = coo_h->Bnnz2; 153692896123SJunchao Zhang Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1)); 153792896123SJunchao Zhang Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1)); 153892896123SJunchao Zhang Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1)); 153992896123SJunchao Zhang Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1)); 154092896123SJunchao Zhang Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2)); 154192896123SJunchao Zhang Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1)); 154292896123SJunchao Zhang Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2)); 154392896123SJunchao Zhang Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2)); 154492896123SJunchao Zhang Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1)); 154592896123SJunchao Zhang Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2)); 154692896123SJunchao Zhang Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen)); 154792896123SJunchao Zhang sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen)); 154892896123SJunchao Zhang recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)); 15492c4ab24aSJunchao Zhang PetscCallVoid(PetscObjectReference((PetscObject)sf)); 15502c4ab24aSJunchao Zhang } 15512c4ab24aSJunchao Zhang 15522c4ab24aSJunchao Zhang ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 15532c4ab24aSJunchao Zhang }; 15542c4ab24aSJunchao Zhang 15552c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 15562c4ab24aSJunchao Zhang { 15572c4ab24aSJunchao Zhang PetscFunctionBegin; 15582c4ab24aSJunchao Zhang PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 15592c4ab24aSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 15602c4ab24aSJunchao Zhang } 15612c4ab24aSJunchao Zhang 1562d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1563d71ae5a4SJacob Faibussowitsch { 15642c4ab24aSJunchao Zhang PetscContainer container_h, container_d; 15652c4ab24aSJunchao Zhang MatCOOStruct_MPIAIJ *coo_h; 15662c4ab24aSJunchao Zhang MatCOOStruct_MPIAIJKokkos *coo_d; 156742550becSJunchao Zhang 156842550becSJunchao Zhang PetscFunctionBegin; 156930203840SJunchao Zhang PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1570cbc6b225SStefano Zampini mat->preallocated = PETSC_TRUE; 15719566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 15729566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 15739566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(mat)); 15742c4ab24aSJunchao Zhang 15752c4ab24aSJunchao Zhang // Copy the COO struct to device 15762c4ab24aSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 15772c4ab24aSJunchao Zhang PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 15782c4ab24aSJunchao Zhang PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 15792c4ab24aSJunchao Zhang 15802c4ab24aSJunchao Zhang // Put the COO struct in a container and then attach that to the matrix 15812c4ab24aSJunchao Zhang PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 15822c4ab24aSJunchao Zhang PetscCall(PetscContainerSetPointer(container_d, coo_d)); 15832c4ab24aSJunchao Zhang PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 15842c4ab24aSJunchao Zhang PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 15852c4ab24aSJunchao Zhang PetscCall(PetscContainerDestroy(&container_d)); 15863ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 158742550becSJunchao Zhang } 158842550becSJunchao Zhang 1589d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1590d71ae5a4SJacob Faibussowitsch { 1591394ed5ebSJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 159242550becSJunchao Zhang Mat A = mpiaij->A, B = mpiaij->B; 159342550becSJunchao Zhang MatScalarKokkosView Aa, Ba; 1594394ed5ebSJunchao Zhang MatScalarKokkosView v1; 159542550becSJunchao Zhang PetscMemType memtype; 15962c4ab24aSJunchao Zhang PetscContainer container; 15972c4ab24aSJunchao Zhang MatCOOStruct_MPIAIJKokkos *coo; 159892896123SJunchao Zhang Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 159942550becSJunchao Zhang 160042550becSJunchao Zhang PetscFunctionBegin; 16012c4ab24aSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 16022c4ab24aSJunchao Zhang PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 16032c4ab24aSJunchao Zhang 16042c4ab24aSJunchao Zhang const auto &n = coo->n; 16052c4ab24aSJunchao Zhang const auto &Annz = coo->Annz; 16062c4ab24aSJunchao Zhang const auto &Annz2 = coo->Annz2; 16072c4ab24aSJunchao Zhang const auto &Bnnz = coo->Bnnz; 16082c4ab24aSJunchao Zhang const auto &Bnnz2 = coo->Bnnz2; 16092c4ab24aSJunchao Zhang const auto &vsend = coo->sendbuf; 16102c4ab24aSJunchao Zhang const auto &v2 = coo->recvbuf; 16112c4ab24aSJunchao Zhang const auto &Ajmap1 = coo->Ajmap1; 16122c4ab24aSJunchao Zhang const auto &Ajmap2 = coo->Ajmap2; 16132c4ab24aSJunchao Zhang const auto &Aimap2 = coo->Aimap2; 16142c4ab24aSJunchao Zhang const auto &Bjmap1 = coo->Bjmap1; 16152c4ab24aSJunchao Zhang const auto &Bjmap2 = coo->Bjmap2; 16162c4ab24aSJunchao Zhang const auto &Bimap2 = coo->Bimap2; 16172c4ab24aSJunchao Zhang const auto &Aperm1 = coo->Aperm1; 16182c4ab24aSJunchao Zhang const auto &Aperm2 = coo->Aperm2; 16192c4ab24aSJunchao Zhang const auto &Bperm1 = coo->Bperm1; 16202c4ab24aSJunchao Zhang const auto &Bperm2 = coo->Bperm2; 16212c4ab24aSJunchao Zhang const auto &Cperm1 = coo->Cperm1; 16222c4ab24aSJunchao Zhang 16239566063dSJacob Faibussowitsch PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 162442550becSJunchao Zhang if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 162592896123SJunchao Zhang v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n)); 162642550becSJunchao Zhang } else { 16272c4ab24aSJunchao Zhang v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 162842550becSJunchao Zhang } 162942550becSJunchao Zhang 163042550becSJunchao Zhang if (imode == INSERT_VALUES) { 16319566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 16329566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1633394ed5ebSJunchao Zhang } else { 16349566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 16359566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 163642550becSJunchao Zhang } 163742550becSJunchao Zhang 163808bb9926SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 163942550becSJunchao Zhang /* Pack entries to be sent to remote */ 164092896123SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 164142550becSJunchao Zhang 164242550becSJunchao Zhang /* Send remote entries to their owner and overlap the communication with local computation */ 16432c4ab24aSJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1644158ec288SJunchao Zhang /* Add local entries to A and B in one kernel */ 16459371c9d4SSatish Balay Kokkos::parallel_for( 164692896123SJunchao Zhang Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1647158ec288SJunchao Zhang PetscScalar sum = 0.0; 1648158ec288SJunchao Zhang if (i < Annz) { 1649158ec288SJunchao Zhang for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1650ac38520cSJunchao Zhang Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1651158ec288SJunchao Zhang } else { 1652158ec288SJunchao Zhang i -= Annz; 1653158ec288SJunchao Zhang for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1654ac38520cSJunchao Zhang Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1655158ec288SJunchao Zhang } 1656158ec288SJunchao Zhang }); 16572c4ab24aSJunchao Zhang PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 165842550becSJunchao Zhang 1659158ec288SJunchao Zhang /* Add received remote entries to A and B in one kernel */ 16609371c9d4SSatish Balay Kokkos::parallel_for( 166192896123SJunchao Zhang Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1662158ec288SJunchao Zhang if (i < Annz2) { 1663158ec288SJunchao Zhang for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1664158ec288SJunchao Zhang } else { 1665158ec288SJunchao Zhang i -= Annz2; 1666158ec288SJunchao Zhang for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1667158ec288SJunchao Zhang } 1668158ec288SJunchao Zhang }); 166908bb9926SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 167042550becSJunchao Zhang 1671394ed5ebSJunchao Zhang if (imode == INSERT_VALUES) { 16729566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */ 16739566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1674394ed5ebSJunchao Zhang } else { 16759566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 16769566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1677394ed5ebSJunchao Zhang } 16783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 167942550becSJunchao Zhang } 168042550becSJunchao Zhang 16812c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1682d71ae5a4SJacob Faibussowitsch { 1683076ba34aSJunchao Zhang PetscFunctionBegin; 16849566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 16859566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 16869566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 16879566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 16889566063dSJacob Faibussowitsch PetscCall(MatDestroy_MPIAIJ(A)); 16893ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1690076ba34aSJunchao Zhang } 1691076ba34aSJunchao Zhang 1692f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1693f4747e26SJunchao Zhang { 1694f4747e26SJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1695f4747e26SJunchao Zhang PetscBool congruent; 1696f4747e26SJunchao Zhang 1697f4747e26SJunchao Zhang PetscFunctionBegin; 1698f4747e26SJunchao Zhang PetscCall(MatHasCongruentLayouts(A, &congruent)); 1699f4747e26SJunchao Zhang if (congruent) { // square matrix and the diagonals are solely in the diag block 1700f4747e26SJunchao Zhang PetscCall(MatShift(mpiaij->A, a)); 1701f4747e26SJunchao Zhang } else { // too hard, use the general version 1702f4747e26SJunchao Zhang PetscCall(MatShift_Basic(A, a)); 1703f4747e26SJunchao Zhang } 1704f4747e26SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1705f4747e26SJunchao Zhang } 1706f4747e26SJunchao Zhang 17072c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 17082c4ab24aSJunchao Zhang { 17092c4ab24aSJunchao Zhang PetscFunctionBegin; 17102c4ab24aSJunchao Zhang B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 17112c4ab24aSJunchao Zhang B->ops->mult = MatMult_MPIAIJKokkos; 17122c4ab24aSJunchao Zhang B->ops->multadd = MatMultAdd_MPIAIJKokkos; 17132c4ab24aSJunchao Zhang B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 17142c4ab24aSJunchao Zhang B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 17152c4ab24aSJunchao Zhang B->ops->destroy = MatDestroy_MPIAIJKokkos; 1716f4747e26SJunchao Zhang B->ops->shift = MatShift_MPIAIJKokkos; 17172c4ab24aSJunchao Zhang 17182c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 17192c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 17202c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 17212c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 17222c4ab24aSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 17232c4ab24aSJunchao Zhang } 17242c4ab24aSJunchao Zhang 1725d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1726d71ae5a4SJacob Faibussowitsch { 17278c3ff71bSJunchao Zhang Mat B; 1728076ba34aSJunchao Zhang Mat_MPIAIJ *a; 17298c3ff71bSJunchao Zhang 17308c3ff71bSJunchao Zhang PetscFunctionBegin; 17318c3ff71bSJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 17329566063dSJacob Faibussowitsch PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 17338c3ff71bSJunchao Zhang } else if (reuse == MAT_REUSE_MATRIX) { 17349566063dSJacob Faibussowitsch PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 17358c3ff71bSJunchao Zhang } 17368c3ff71bSJunchao Zhang B = *newmat; 17378c3ff71bSJunchao Zhang 17386f3d89d0SStefano Zampini B->boundtocpu = PETSC_FALSE; 17399566063dSJacob Faibussowitsch PetscCall(PetscFree(B->defaultvectype)); 17409566063dSJacob Faibussowitsch PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 17419566063dSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 17428c3ff71bSJunchao Zhang 1743076ba34aSJunchao Zhang a = static_cast<Mat_MPIAIJ *>(A->data); 17449566063dSJacob Faibussowitsch if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 17459566063dSJacob Faibussowitsch if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 17469566063dSJacob Faibussowitsch if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 17472c4ab24aSJunchao Zhang PetscCall(MatSetOps_MPIAIJKokkos(B)); 17483ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17498c3ff71bSJunchao Zhang } 17502c4ab24aSJunchao Zhang 17513f3ba80aSJunchao Zhang /*MC 175211a5261eSBarry Smith MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 17538c3ff71bSJunchao Zhang 175415229ffcSPierre Jolivet A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 17553f3ba80aSJunchao Zhang 17562ef1f0ffSBarry Smith Options Database Key: 17572ef1f0ffSBarry Smith . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 17583f3ba80aSJunchao Zhang 17593f3ba80aSJunchao Zhang Level: beginner 17603f3ba80aSJunchao Zhang 17611cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 17623f3ba80aSJunchao Zhang M*/ 1763d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1764d71ae5a4SJacob Faibussowitsch { 17658c3ff71bSJunchao Zhang PetscFunctionBegin; 17669566063dSJacob Faibussowitsch PetscCall(PetscKokkosInitializeCheck()); 17679566063dSJacob Faibussowitsch PetscCall(MatCreate_MPIAIJ(A)); 17689566063dSJacob Faibussowitsch PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 17693ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17708c3ff71bSJunchao Zhang } 17718c3ff71bSJunchao Zhang 17728c3ff71bSJunchao Zhang /*@C 177311a5261eSBarry Smith MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 17748c3ff71bSJunchao Zhang (the default parallel PETSc format). This matrix will ultimately pushed down 177520f4b53cSBarry Smith to Kokkos for calculations. 17768c3ff71bSJunchao Zhang 17778c3ff71bSJunchao Zhang Collective 17788c3ff71bSJunchao Zhang 17798c3ff71bSJunchao Zhang Input Parameters: 178011a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF` 178120f4b53cSBarry Smith . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 178220f4b53cSBarry Smith This value should be the same as the local size used in creating the 178320f4b53cSBarry Smith y vector for the matrix-vector product y = Ax. 178420f4b53cSBarry Smith . n - This value should be the same as the local size used in creating the 178520f4b53cSBarry Smith x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 178620f4b53cSBarry Smith calculated if N is given) For square matrices n is almost always `m`. 178720f4b53cSBarry Smith . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 178820f4b53cSBarry Smith . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 178920f4b53cSBarry Smith . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 179020f4b53cSBarry Smith (same value is used for all local rows) 179120f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the 179220f4b53cSBarry Smith DIAGONAL portion of the local submatrix (possibly different for each row) 179320f4b53cSBarry Smith or `NULL`, if `d_nz` is used to specify the nonzero structure. 179420f4b53cSBarry Smith The size of this array is equal to the number of local rows, i.e `m`. 179520f4b53cSBarry Smith For matrices you plan to factor you must leave room for the diagonal entry and 179620f4b53cSBarry Smith put in the entry even if it is zero. 179720f4b53cSBarry Smith . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 179820f4b53cSBarry Smith submatrix (same value is used for all local rows). 179920f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the 180020f4b53cSBarry Smith OFF-DIAGONAL portion of the local submatrix (possibly different for 180120f4b53cSBarry Smith each row) or `NULL`, if `o_nz` is used to specify the nonzero 180220f4b53cSBarry Smith structure. The size of this array is equal to the number 180320f4b53cSBarry Smith of local rows, i.e `m`. 18048c3ff71bSJunchao Zhang 18058c3ff71bSJunchao Zhang Output Parameter: 18068c3ff71bSJunchao Zhang . A - the matrix 18078c3ff71bSJunchao Zhang 18082ef1f0ffSBarry Smith Level: intermediate 18092ef1f0ffSBarry Smith 18102ef1f0ffSBarry Smith Notes: 181111a5261eSBarry Smith It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 18128c3ff71bSJunchao Zhang MatXXXXSetPreallocation() paradigm instead of this routine directly. 181311a5261eSBarry Smith [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 18148c3ff71bSJunchao Zhang 1815667f096bSBarry Smith The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 18168c3ff71bSJunchao Zhang storage. That is, the stored row and column indices can begin at 18172ef1f0ffSBarry Smith either one (as in Fortran) or zero. 18188c3ff71bSJunchao Zhang 18191cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1820fe59aa6dSJacob Faibussowitsch `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 18218c3ff71bSJunchao Zhang @*/ 1822d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1823d71ae5a4SJacob Faibussowitsch { 18248c3ff71bSJunchao Zhang PetscMPIInt size; 18258c3ff71bSJunchao Zhang 18268c3ff71bSJunchao Zhang PetscFunctionBegin; 18279566063dSJacob Faibussowitsch PetscCall(MatCreate(comm, A)); 18289566063dSJacob Faibussowitsch PetscCall(MatSetSizes(*A, m, n, M, N)); 18299566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size)); 18308c3ff71bSJunchao Zhang if (size > 1) { 18319566063dSJacob Faibussowitsch PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 18329566063dSJacob Faibussowitsch PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 18338c3ff71bSJunchao Zhang } else { 18349566063dSJacob Faibussowitsch PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 18359566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 18368c3ff71bSJunchao Zhang } 18373ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 18388c3ff71bSJunchao Zhang } 1839