1d326c3f1SJunchao Zhang #include <petsc_kokkos.hpp> 211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp> 3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp> 4f0e6e2d1SJunchao Zhang #include <petscpkg_version.h> 542550becSJunchao Zhang #include <petsc/private/sfimpl.h> 692896123SJunchao Zhang #include <petsc/private/kokkosimpl.hpp> 72c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 88c3ff71bSJunchao Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h> 9076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp> 100e3ece09SJunchao Zhang #include <KokkosSparse_spgemm.hpp> 1111d22bbfSJunchao Zhang 1266976f2fSJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 13d71ae5a4SJacob Faibussowitsch { 1430203840SJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 158c3ff71bSJunchao Zhang 168c3ff71bSJunchao Zhang PetscFunctionBegin; 179566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 1830203840SJunchao Zhang /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 1930203840SJunchao Zhang Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 
2030203840SJunchao Zhang */ 2130203840SJunchao Zhang if (mode == MAT_FINAL_ASSEMBLY) { 2292896123SJunchao Zhang PetscScalarKokkosView v; 2392896123SJunchao Zhang 2430203840SJunchao Zhang PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 2530203840SJunchao Zhang PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 2692896123SJunchao Zhang PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device 2792896123SJunchao Zhang PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device 2892896123SJunchao Zhang PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v)); 2930203840SJunchao Zhang } 303ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 318c3ff71bSJunchao Zhang } 328c3ff71bSJunchao Zhang 3366976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 34d71ae5a4SJacob Faibussowitsch { 352cdb1aeaSJunchao Zhang Mat_MPIAIJ *mpiaij; 368c3ff71bSJunchao Zhang 378c3ff71bSJunchao Zhang PetscFunctionBegin; 382cdb1aeaSJunchao Zhang // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things 392cdb1aeaSJunchao Zhang PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz)); 402cdb1aeaSJunchao Zhang mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 412cdb1aeaSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A)); 422cdb1aeaSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B)); 433ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 448c3ff71bSJunchao Zhang } 458c3ff71bSJunchao Zhang 4666976f2fSJacob Faibussowitsch static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 47d71ae5a4SJacob Faibussowitsch { 488c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ 
*)mat->data; 498c3ff71bSJunchao Zhang PetscInt nt; 508c3ff71bSJunchao Zhang 518c3ff71bSJunchao Zhang PetscFunctionBegin; 529566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 5308401ef6SPierre Jolivet PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 549566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 559566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 569566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 579566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 583ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 598c3ff71bSJunchao Zhang } 608c3ff71bSJunchao Zhang 6166976f2fSJacob Faibussowitsch static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 62d71ae5a4SJacob Faibussowitsch { 638c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 648c3ff71bSJunchao Zhang PetscInt nt; 658c3ff71bSJunchao Zhang 668c3ff71bSJunchao Zhang PetscFunctionBegin; 679566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 6808401ef6SPierre Jolivet PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 699566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 709566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 719566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 729566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 733ba16761SJacob 
Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 748c3ff71bSJunchao Zhang } 758c3ff71bSJunchao Zhang 7666976f2fSJacob Faibussowitsch static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 77d71ae5a4SJacob Faibussowitsch { 788c3ff71bSJunchao Zhang Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 798c3ff71bSJunchao Zhang PetscInt nt; 808c3ff71bSJunchao Zhang 818c3ff71bSJunchao Zhang PetscFunctionBegin; 829566063dSJacob Faibussowitsch PetscCall(VecGetLocalSize(xx, &nt)); 8308401ef6SPierre Jolivet PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 849566063dSJacob Faibussowitsch PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 859566063dSJacob Faibussowitsch PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 869566063dSJacob Faibussowitsch PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 879566063dSJacob Faibussowitsch PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 883ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 898c3ff71bSJunchao Zhang } 908c3ff71bSJunchao Zhang 91076ba34aSJunchao Zhang /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 92076ba34aSJunchao Zhang A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 93076ba34aSJunchao Zhang C still uses local column ids. Their corresponding global column ids are returned in glob. 
94076ba34aSJunchao Zhang */ 9566976f2fSJacob Faibussowitsch static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 96d71ae5a4SJacob Faibussowitsch { 97076ba34aSJunchao Zhang Mat Ad, Ao; 98076ba34aSJunchao Zhang const PetscInt *cmap; 99076ba34aSJunchao Zhang 100076ba34aSJunchao Zhang PetscFunctionBegin; 1019566063dSJacob Faibussowitsch PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 1029566063dSJacob Faibussowitsch PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 103076ba34aSJunchao Zhang if (glob) { 104076ba34aSJunchao Zhang PetscInt cst, i, dn, on, *gidx; 1059566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 1069566063dSJacob Faibussowitsch PetscCall(MatGetLocalSize(Ao, NULL, &on)); 1079566063dSJacob Faibussowitsch PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 1089566063dSJacob Faibussowitsch PetscCall(PetscMalloc1(dn + on, &gidx)); 109076ba34aSJunchao Zhang for (i = 0; i < dn; i++) gidx[i] = cst + i; 110076ba34aSJunchao Zhang for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 1119566063dSJacob Faibussowitsch PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 112076ba34aSJunchao Zhang } 1133ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 114076ba34aSJunchao Zhang } 115076ba34aSJunchao Zhang 1160e3ece09SJunchao Zhang /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 117076ba34aSJunchao Zhang struct MatMatStruct { 1180e3ece09SJunchao Zhang PetscInt n, *garray; // C's garray and its size. 
1190e3ece09SJunchao Zhang KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 1200e3ece09SJunchao Zhang KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 1210e3ece09SJunchao Zhang KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 1220e3ece09SJunchao Zhang PetscIntKokkosView E_NzLeft; 1230e3ece09SJunchao Zhang PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 1240e3ece09SJunchao Zhang MatScalarKokkosView rootBuf, leafBuf; 1250e3ece09SJunchao Zhang KokkosCsrMatrix Fd, Fo; // F in split form 1260e3ece09SJunchao Zhang 1270e3ece09SJunchao Zhang KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 1280e3ece09SJunchao Zhang KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 1290e3ece09SJunchao Zhang KernelHandle kh3; // compute C3 1300e3ece09SJunchao Zhang KernelHandle kh4; // compute C4 1310e3ece09SJunchao Zhang 132aaa8cc7dSPierre Jolivet PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 1330e3ece09SJunchao Zhang PetscInt E_VectorLength; 1340e3ece09SJunchao Zhang PetscInt E_RowsPerTeam; 1350e3ece09SJunchao Zhang PetscInt F_TeamSize; 1360e3ece09SJunchao Zhang PetscInt F_VectorLength; 1370e3ece09SJunchao Zhang PetscInt F_RowsPerTeam; 138076ba34aSJunchao Zhang 139d71ae5a4SJacob Faibussowitsch ~MatMatStruct() 140d71ae5a4SJacob Faibussowitsch { 1413ba16761SJacob Faibussowitsch PetscFunctionBegin; 1423ba16761SJacob Faibussowitsch PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 1433ba16761SJacob Faibussowitsch PetscFunctionReturnVoid(); 144076ba34aSJunchao Zhang } 145076ba34aSJunchao Zhang }; 146076ba34aSJunchao Zhang 147076ba34aSJunchao Zhang struct MatMatStruct_AB : public MatMatStruct { 1480e3ece09SJunchao Zhang PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 1490e3ece09SJunchao Zhang PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 1500e3ece09SJunchao Zhang 
PetscIntKokkosView rowoffset; 151076ba34aSJunchao Zhang }; 152076ba34aSJunchao Zhang 153076ba34aSJunchao Zhang struct MatMatStruct_AtB : public MatMatStruct { 1540e3ece09SJunchao Zhang MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 1550e3ece09SJunchao Zhang MatColIdxKokkosView Fdjperm; 1560e3ece09SJunchao Zhang MatColIdxKokkosView Fojmap; 1570e3ece09SJunchao Zhang MatColIdxKokkosView Fojperm; 158076ba34aSJunchao Zhang }; 159076ba34aSJunchao Zhang 160cc1eb50dSBarry Smith struct MatProductCtx_MPIAIJKokkos { 1613ba16761SJacob Faibussowitsch MatMatStruct_AB *mmAB = nullptr; 1623ba16761SJacob Faibussowitsch MatMatStruct_AtB *mmAtB = nullptr; 1633ba16761SJacob Faibussowitsch PetscBool reusesym = PETSC_FALSE; 1640e3ece09SJunchao Zhang Mat Z = nullptr; // store Z=AB in computing BtAB 165076ba34aSJunchao Zhang 166cc1eb50dSBarry Smith ~MatProductCtx_MPIAIJKokkos() 167d71ae5a4SJacob Faibussowitsch { 168076ba34aSJunchao Zhang delete mmAB; 169076ba34aSJunchao Zhang delete mmAtB; 1700e3ece09SJunchao Zhang PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 171076ba34aSJunchao Zhang } 172076ba34aSJunchao Zhang }; 173076ba34aSJunchao Zhang 174*2a8381b2SBarry Smith static PetscErrorCode MatProductCtxDestroy_MPIAIJKokkos(PetscCtxRt data) 175d71ae5a4SJacob Faibussowitsch { 176076ba34aSJunchao Zhang PetscFunctionBegin; 177cc1eb50dSBarry Smith PetscCallCXX(delete *reinterpret_cast<MatProductCtx_MPIAIJKokkos **>(data)); 1783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 179076ba34aSJunchao Zhang } 180076ba34aSJunchao Zhang 1810e3ece09SJunchao Zhang // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 1820e3ece09SJunchao Zhang // split csr matrices. 
The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 1830e3ece09SJunchao Zhang template <class ExecutionSpace> 1840e3ece09SJunchao Zhang static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 185d71ae5a4SJacob Faibussowitsch { 1861aa660a0SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1) 1871aa660a0SJunchao Zhang constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>(); 1881aa660a0SJunchao Zhang #else 1891aa660a0SJunchao Zhang constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>; 1901aa660a0SJunchao Zhang #endif 1910e3ece09SJunchao Zhang Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 192076ba34aSJunchao Zhang 193076ba34aSJunchao Zhang PetscFunctionBegin; 1940e3ece09SJunchao Zhang PetscInt nnz_per_row = numRows ? 
(nnz / numRows) : 0; // we might meet empty matrices 195076ba34aSJunchao Zhang 1960e3ece09SJunchao Zhang if (nnz_per_row < 1) nnz_per_row = 1; 197076ba34aSJunchao Zhang 1980e3ece09SJunchao Zhang int max_vector_length = teamPolicy.vector_length_max(); 199076ba34aSJunchao Zhang 2000e3ece09SJunchao Zhang if (vector_length < 1) { 2010e3ece09SJunchao Zhang vector_length = 1; 2020e3ece09SJunchao Zhang while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 203076ba34aSJunchao Zhang } 204076ba34aSJunchao Zhang 2050e3ece09SJunchao Zhang // Determine rows per thread 2060e3ece09SJunchao Zhang if (rows_per_thread < 1) { 2071aa660a0SJunchao Zhang if (is_gpu_exec_space) rows_per_thread = 1; 2080e3ece09SJunchao Zhang else { 2090e3ece09SJunchao Zhang if (nnz_per_row < 20 && nnz > 5000000) { 2100e3ece09SJunchao Zhang rows_per_thread = 256; 2110e3ece09SJunchao Zhang } else rows_per_thread = 64; 212076ba34aSJunchao Zhang } 213076ba34aSJunchao Zhang } 214076ba34aSJunchao Zhang 2150e3ece09SJunchao Zhang if (team_size < 1) { 2161aa660a0SJunchao Zhang if (is_gpu_exec_space) { 2170e3ece09SJunchao Zhang team_size = 256 / vector_length; 218076ba34aSJunchao Zhang } else { 2190e3ece09SJunchao Zhang team_size = 1; 2200e3ece09SJunchao Zhang } 221076ba34aSJunchao Zhang } 222076ba34aSJunchao Zhang 2230e3ece09SJunchao Zhang rows_per_team = rows_per_thread * team_size; 224076ba34aSJunchao Zhang 2250e3ece09SJunchao Zhang if (rows_per_team < 0) { 2260e3ece09SJunchao Zhang PetscInt nnz_per_team = 4096; 2270e3ece09SJunchao Zhang PetscInt conc = ExecutionSpace().concurrency(); 2280e3ece09SJunchao Zhang while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 2290e3ece09SJunchao Zhang rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 2300e3ece09SJunchao Zhang } 2313ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 232076ba34aSJunchao Zhang } 233076ba34aSJunchao Zhang 2340e3ece09SJunchao Zhang /* 
2350e3ece09SJunchao Zhang Reduce two sets of global indices into local ones 236076ba34aSJunchao Zhang 237076ba34aSJunchao Zhang Input Parameters: 2380e3ece09SJunchao Zhang + n1 - size of garray1[], the first set 2390e3ece09SJunchao Zhang . garray1[n1] - a sorted global index array (without duplicates) 2400e3ece09SJunchao Zhang . m - size of indices[], the second set 2410e3ece09SJunchao Zhang - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 242076ba34aSJunchao Zhang 243076ba34aSJunchao Zhang Output Parameters: 2440e3ece09SJunchao Zhang + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 2450e3ece09SJunchao Zhang . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 2460e3ece09SJunchao Zhang . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 2470e3ece09SJunchao Zhang - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 248076ba34aSJunchao Zhang 2490e3ece09SJunchao Zhang Example, say 2500e3ece09SJunchao Zhang n1 = 5 2510e3ece09SJunchao Zhang garray1[5] = {1, 4, 7, 8, 10} 2520e3ece09SJunchao Zhang m = 4 2530e3ece09SJunchao Zhang indices[4] = {2, 4, 8, 9} 25411a5261eSBarry Smith 2550e3ece09SJunchao Zhang Combining them together, we have 7 global indices in garray2[] 2560e3ece09SJunchao Zhang n2 = 7 2570e3ece09SJunchao Zhang garray2[7] = {1, 2, 4, 7, 8, 9, 10} 2580e3ece09SJunchao Zhang 2590e3ece09SJunchao Zhang And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 2600e3ece09SJunchao Zhang map[5] = {0, 2, 3, 4, 6} 2610e3ece09SJunchao Zhang 2620e3ece09SJunchao Zhang On output, indices[] is updated with local indices 2630e3ece09SJunchao Zhang indices[4] = {1, 2, 4, 5} 264076ba34aSJunchao Zhang */ 2650e3ece09SJunchao Zhang static PetscErrorCode 
ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 266d71ae5a4SJacob Faibussowitsch { 2670e3ece09SJunchao Zhang PetscHMapI g2l = nullptr; 2680e3ece09SJunchao Zhang PetscHashIter iter; 2690e3ece09SJunchao Zhang PetscInt tot, key, val; // total unique global indices. key is global id; val is local id 2700e3ece09SJunchao Zhang PetscInt n2, *garray2; 271076ba34aSJunchao Zhang 272076ba34aSJunchao Zhang PetscFunctionBegin; 2730e3ece09SJunchao Zhang tot = 0; 2740e3ece09SJunchao Zhang PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 2750e3ece09SJunchao Zhang for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 2760e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 2770e3ece09SJunchao Zhang if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 278076ba34aSJunchao Zhang } 279076ba34aSJunchao Zhang 2800e3ece09SJunchao Zhang for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 2810e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 2820e3ece09SJunchao Zhang if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 283076ba34aSJunchao Zhang } 284076ba34aSJunchao Zhang 2850e3ece09SJunchao Zhang // Pull out (unique) globals in the hash table and put them in garray2[] 2860e3ece09SJunchao Zhang n2 = tot; 2870e3ece09SJunchao Zhang PetscCall(PetscMalloc1(n2, &garray2)); 2880e3ece09SJunchao Zhang tot = 0; 2890e3ece09SJunchao Zhang PetscHashIterBegin(g2l, iter); 2900e3ece09SJunchao Zhang while (!PetscHashIterAtEnd(g2l, iter)) { 2910e3ece09SJunchao Zhang PetscHashIterGetKey(g2l, iter, key); 2920e3ece09SJunchao Zhang PetscHashIterNext(g2l, iter); 2930e3ece09SJunchao Zhang garray2[tot++] = key; 294076ba34aSJunchao Zhang } 295076ba34aSJunchao Zhang 2960e3ece09SJunchao Zhang // Sort garray2[] and then map 
them to local indices starting from 0 2970e3ece09SJunchao Zhang PetscCall(PetscSortInt(n2, garray2)); 2980e3ece09SJunchao Zhang PetscCall(PetscHMapIClear(g2l)); 2990e3ece09SJunchao Zhang for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 300f0e6e2d1SJunchao Zhang 3010e3ece09SJunchao Zhang // Rewrite indices[] with local indices 302f0e6e2d1SJunchao Zhang for (PetscInt i = 0; i < m; i++) { 3030e3ece09SJunchao Zhang PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 3040e3ece09SJunchao Zhang PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 3050e3ece09SJunchao Zhang indices[i] = val; 3060e3ece09SJunchao Zhang } 3070e3ece09SJunchao Zhang // Record the map that maps garray1[i] to garray2[map[i]] 3080e3ece09SJunchao Zhang for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 3090e3ece09SJunchao Zhang PetscCall(PetscHMapIDestroy(&g2l)); 3100e3ece09SJunchao Zhang *n2_ = n2; 3110e3ece09SJunchao Zhang *garray2_ = garray2; 3120e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3130e3ece09SJunchao Zhang } 314f0e6e2d1SJunchao Zhang 3150e3ece09SJunchao Zhang /* 3160e3ece09SJunchao Zhang MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 3170e3ece09SJunchao Zhang 3180e3ece09SJunchao Zhang It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 3190e3ece09SJunchao Zhang 3200e3ece09SJunchao Zhang Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 3210e3ece09SJunchao Zhang In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 

  Input Parameters:
+ comm        - MPI communicator of E
. A           - diag block of E, using local column indices
. B           - off-diag block of E, using local column indices
. cstart      - (global) start column of Ed
. cend        - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend)
. garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
. ownerSF     - the SF specifies ownership (root) of rows in E
. reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm          - to stash intermediate data structures for reuse

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
- mm      - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
3400e3ece09SJunchao Zhang 3410e3ece09SJunchao Zhang */ 3420e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 3430e3ece09SJunchao Zhang { 3440e3ece09SJunchao Zhang PetscFunctionBegin; 3450e3ece09SJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 3460e3ece09SJunchao Zhang PetscInt Em = A.numRows(), Fm; 3470e3ece09SJunchao Zhang PetscInt n1 = B.numCols(); 3480e3ece09SJunchao Zhang 3490e3ece09SJunchao Zhang PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 3500e3ece09SJunchao Zhang 3510e3ece09SJunchao Zhang // Do the analysis on host 35245402d8aSJunchao Zhang auto Ai_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.row_map); 35345402d8aSJunchao Zhang auto Aj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), A.graph.entries); 35445402d8aSJunchao Zhang auto Bi_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.row_map); 35545402d8aSJunchao Zhang auto Bj_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), B.graph.entries); 3560e3ece09SJunchao Zhang const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 3570e3ece09SJunchao Zhang const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 3580e3ece09SJunchao Zhang 3590e3ece09SJunchao Zhang // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 3607b8d4ba6SJunchao Zhang PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 3610e3ece09SJunchao Zhang PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 3620e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) { 3630e3ece09SJunchao Zhang const PetscInt *first, *last, *it; 3640e3ece09SJunchao Zhang PetscInt count, step; 3650e3ece09SJunchao Zhang // 
std::lower_bound(first,last,cstart), but need to use global column indices 3660e3ece09SJunchao Zhang first = Bj + Bi[i]; 3670e3ece09SJunchao Zhang last = Bj + Bi[i + 1]; 368f0e6e2d1SJunchao Zhang count = last - first; 369f0e6e2d1SJunchao Zhang while (count > 0) { 370f0e6e2d1SJunchao Zhang it = first; 371f0e6e2d1SJunchao Zhang step = count / 2; 372f0e6e2d1SJunchao Zhang it += step; 3730e3ece09SJunchao Zhang if (garray1[*it] < cstart) { // map local to global 374f0e6e2d1SJunchao Zhang first = ++it; 375f0e6e2d1SJunchao Zhang count -= step + 1; 376f0e6e2d1SJunchao Zhang } else count = step; 377f0e6e2d1SJunchao Zhang } 3780e3ece09SJunchao Zhang E_NzLeft[i] = first - (Bj + Bi[i]); 3790e3ece09SJunchao Zhang E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 380f0e6e2d1SJunchao Zhang } 381f0e6e2d1SJunchao Zhang 3820e3ece09SJunchao Zhang // Get length of rows (i.e., sizes of leaves) that contribute to my roots 3830e3ece09SJunchao Zhang const PetscMPIInt *iranks, *ranks; 3840e3ece09SJunchao Zhang const PetscInt *ioffset, *irootloc, *roffset, *rmine; 385c09cee04SJames Wright PetscMPIInt niranks, nranks; 3860e3ece09SJunchao Zhang MPI_Request *reqs; 3870e3ece09SJunchao Zhang PetscMPIInt tag; 3880e3ece09SJunchao Zhang PetscSF reduceSF; 3890e3ece09SJunchao Zhang PetscInt *sdisp, *rdisp; 390f0e6e2d1SJunchao Zhang 3910e3ece09SJunchao Zhang PetscCall(PetscCommGetNewTag(comm, &tag)); 3920e3ece09SJunchao Zhang PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 3930e3ece09SJunchao Zhang PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 394f0e6e2d1SJunchao Zhang 3950e3ece09SJunchao Zhang // Find out length of each row I will receive. 
Even for the same row index, when they are from 3960e3ece09SJunchao Zhang // different senders, they might have different lengths (and sparsity patterns) 3970e3ece09SJunchao Zhang PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 3980e3ece09SJunchao Zhang PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 399f0e6e2d1SJunchao Zhang 4000e3ece09SJunchao Zhang PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 4010e3ece09SJunchao Zhang 4020e3ece09SJunchao Zhang for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 4030e3ece09SJunchao Zhang recvRowLen[0] = 0; // since we will make it in CSR format later 4040e3ece09SJunchao Zhang recvRowLen++; // advance the pointer now 405458b0db5SMartin Diehl for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 406458b0db5SMartin Diehl for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i])); 4070e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 4080e3ece09SJunchao Zhang 4090e3ece09SJunchao Zhang // Build the real PetscSF for reducing E rows (buffer to buffer) 4100e3ece09SJunchao Zhang rdisp[0] = 0; 4110e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { 4120e3ece09SJunchao Zhang rdisp[i + 1] = rdisp[i]; 413ac530a7eSPierre Jolivet for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) rdisp[i + 1] += recvRowLen[j]; 4140e3ece09SJunchao Zhang } 4150e3ece09SJunchao Zhang recvRowLen--; // put it back into csr format 4160e3ece09SJunchao Zhang for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i]; 4170e3ece09SJunchao Zhang 418458b0db5SMartin Diehl for (PetscInt i = 0; i < nranks; 
i++) PetscCallMPI(MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 419458b0db5SMartin Diehl for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 4200e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 4210e3ece09SJunchao Zhang 4220e3ece09SJunchao Zhang PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 4230e3ece09SJunchao Zhang PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 4240e3ece09SJunchao Zhang PetscSFNode *iremote; 4250e3ece09SJunchao Zhang 4260e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 4270e3ece09SJunchao Zhang PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 4280e3ece09SJunchao Zhang PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 4290e3ece09SJunchao Zhang 4300e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { 4310e3ece09SJunchao Zhang PetscInt count = 0; 4320e3ece09SJunchao Zhang for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 4330e3ece09SJunchao Zhang for (PetscInt j = 0; j < count; j++) { 4340e3ece09SJunchao Zhang iremote[nleaves + j].rank = ranks[i]; 4350e3ece09SJunchao Zhang iremote[nleaves + j].index = sdisp[i] + j; 4360e3ece09SJunchao Zhang } 4370e3ece09SJunchao Zhang nleaves += count; 4380e3ece09SJunchao Zhang } 4390e3ece09SJunchao Zhang PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 4400e3ece09SJunchao Zhang 4410e3ece09SJunchao Zhang PetscCall(PetscSFCreate(comm, &reduceSF)); 4420e3ece09SJunchao Zhang PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 4430e3ece09SJunchao Zhang 4440e3ece09SJunchao Zhang // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 
4450e3ece09SJunchao Zhang PetscInt *sendCol, *recvCol; 4460e3ece09SJunchao Zhang PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 4470e3ece09SJunchao Zhang for (PetscInt k = 0; k < roffset[nranks]; k++) { 4480e3ece09SJunchao Zhang PetscInt i = rmine[k]; // row to be copied 4490e3ece09SJunchao Zhang PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 4500e3ece09SJunchao Zhang PetscInt nzLeft = E_NzLeft[i]; 4510e3ece09SJunchao Zhang PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 4520e3ece09SJunchao Zhang for (PetscInt j = 0; j < alen + blen; j++) { 4530e3ece09SJunchao Zhang if (j < nzLeft) { 4540e3ece09SJunchao Zhang buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 4550e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { 4560e3ece09SJunchao Zhang buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 4570e3ece09SJunchao Zhang } else { 4580e3ece09SJunchao Zhang buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 4590e3ece09SJunchao Zhang } 4600e3ece09SJunchao Zhang } 4610e3ece09SJunchao Zhang } 4620e3ece09SJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 4630e3ece09SJunchao Zhang PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 4640e3ece09SJunchao Zhang 4650e3ece09SJunchao Zhang // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 4660e3ece09SJunchao Zhang PetscInt *recvRowPerm, *recvColSorted; 4670e3ece09SJunchao Zhang PetscInt *recvNzPerm, *recvNzPermSorted; 4680e3ece09SJunchao Zhang PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 4690e3ece09SJunchao Zhang 4700e3ece09SJunchao Zhang for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 4710e3ece09SJunchao Zhang for (PetscInt i = 0; i < recvRowCnt; i++) 
recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 4720e3ece09SJunchao Zhang PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 4730e3ece09SJunchao Zhang 4740e3ece09SJunchao Zhang // i[] array, nz are always easiest to compute 4757b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 4760e3ece09SJunchao Zhang MatRowMapType *Fdi, *Foi; 4770e3ece09SJunchao Zhang PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 4780e3ece09SJunchao Zhang PetscInt iter; 4790e3ece09SJunchao Zhang 4800e3ece09SJunchao Zhang Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 4810e3ece09SJunchao Zhang Kokkos::deep_copy(Foi_h, 0); 4820e3ece09SJunchao Zhang Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 4830e3ece09SJunchao Zhang Foi = Foi_h.data() + 1; 4840e3ece09SJunchao Zhang iter = 0; 4850e3ece09SJunchao Zhang while (iter < recvRowCnt) { // iter over received rows 4860e3ece09SJunchao Zhang PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 4870e3ece09SJunchao Zhang PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 4880e3ece09SJunchao Zhang 4890e3ece09SJunchao Zhang while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 4900e3ece09SJunchao Zhang 4910e3ece09SJunchao Zhang // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted 4920e3ece09SJunchao Zhang PetscInt nz = 0; // nz (with dups) in the current row 4930e3ece09SJunchao Zhang PetscInt *jbuf = recvColSorted + FnzDups; 4940e3ece09SJunchao Zhang PetscInt *pbuf = recvNzPermSorted + FnzDups; 4950e3ece09SJunchao Zhang PetscInt *jbuf2 = jbuf; // temp pointers 4960e3ece09SJunchao Zhang PetscInt *pbuf2 = pbuf; 4970e3ece09SJunchao 
Zhang for (PetscInt d = 0; d < dupRows; d++) { 4980e3ece09SJunchao Zhang PetscInt i = recvRowPerm[iter + d]; 4990e3ece09SJunchao Zhang PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 5000e3ece09SJunchao Zhang PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 5010e3ece09SJunchao Zhang PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 5020e3ece09SJunchao Zhang jbuf2 += len; 5030e3ece09SJunchao Zhang pbuf2 += len; 5040e3ece09SJunchao Zhang nz += len; 5050e3ece09SJunchao Zhang } 5060e3ece09SJunchao Zhang PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 5070e3ece09SJunchao Zhang 5080e3ece09SJunchao Zhang // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 5090e3ece09SJunchao Zhang PetscInt cur = 0; 5100e3ece09SJunchao Zhang while (cur < nz) { 5110e3ece09SJunchao Zhang PetscInt curColIdx = jbuf[cur]; 5120e3ece09SJunchao Zhang PetscInt dups = 1; 5130e3ece09SJunchao Zhang 5140e3ece09SJunchao Zhang while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 5150e3ece09SJunchao Zhang if (curColIdx >= cstart && curColIdx < cend) { 5160e3ece09SJunchao Zhang Fdi[curRowIdx]++; 5170e3ece09SJunchao Zhang FdnzDups += dups; 5180e3ece09SJunchao Zhang } else { 5190e3ece09SJunchao Zhang Foi[curRowIdx]++; 5200e3ece09SJunchao Zhang FonzDups += dups; 5210e3ece09SJunchao Zhang } 5220e3ece09SJunchao Zhang cur += dups; 5230e3ece09SJunchao Zhang } 5240e3ece09SJunchao Zhang 5250e3ece09SJunchao Zhang FnzDups += nz; 5260e3ece09SJunchao Zhang iter += dupRows; // Move to next unique row 5270e3ece09SJunchao Zhang } 5280e3ece09SJunchao Zhang 5290e3ece09SJunchao Zhang Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 5300e3ece09SJunchao Zhang Foi = Foi_h.data(); 5310e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 5320e3ece09SJunchao Zhang Fdi[i + 1] += Fdi[i]; 5330e3ece09SJunchao Zhang Foi[i 
+ 1] += Foi[i]; 5340e3ece09SJunchao Zhang } 5350e3ece09SJunchao Zhang Fdnz = Fdi[Fm]; 5360e3ece09SJunchao Zhang Fonz = Foi[Fm]; 5370e3ece09SJunchao Zhang PetscCall(PetscFree2(sendCol, recvCol)); 5380e3ece09SJunchao Zhang 5390e3ece09SJunchao Zhang // Allocate j, jmap, jperm for Fd and Fo 5407b8d4ba6SJunchao Zhang MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 5417b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 5427b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 5430e3ece09SJunchao Zhang MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 5440e3ece09SJunchao Zhang MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 5450e3ece09SJunchao Zhang MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 5460e3ece09SJunchao Zhang 5470e3ece09SJunchao Zhang // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 5480e3ece09SJunchao Zhang Fdjmap[0] = 0; 5490e3ece09SJunchao Zhang Fojmap[0] = 0; 5500e3ece09SJunchao Zhang FnzDups = 0; 5510e3ece09SJunchao Zhang Fdnz = 0; 5520e3ece09SJunchao Zhang Fonz = 0; 5530e3ece09SJunchao Zhang iter = 0; // iter over received rows 5540e3ece09SJunchao Zhang while (iter < recvRowCnt) { 5550e3ece09SJunchao Zhang PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 5560e3ece09SJunchao Zhang PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 5570e3ece09SJunchao Zhang PetscInt nz = 0; // nz (with dups) in the current row 5580e3ece09SJunchao Zhang 5590e3ece09SJunchao Zhang while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 5600e3ece09SJunchao Zhang for (PetscInt d = 0; d < dupRows; d++) { 5610e3ece09SJunchao Zhang PetscInt i = recvRowPerm[iter + d]; 5620e3ece09SJunchao Zhang nz += recvRowLen[i + 1] 
- recvRowLen[i]; 5630e3ece09SJunchao Zhang } 5640e3ece09SJunchao Zhang 5650e3ece09SJunchao Zhang PetscInt *jbuf = recvColSorted + FnzDups; 5660e3ece09SJunchao Zhang // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 5670e3ece09SJunchao Zhang PetscInt cur = 0; 5680e3ece09SJunchao Zhang while (cur < nz) { 5690e3ece09SJunchao Zhang PetscInt curColIdx = jbuf[cur]; 5700e3ece09SJunchao Zhang PetscInt dups = 1; 5710e3ece09SJunchao Zhang 5720e3ece09SJunchao Zhang while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 5730e3ece09SJunchao Zhang if (curColIdx >= cstart && curColIdx < cend) { 5740e3ece09SJunchao Zhang Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 5750e3ece09SJunchao Zhang Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 5760e3ece09SJunchao Zhang for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 5770e3ece09SJunchao Zhang FdnzDups += dups; 5780e3ece09SJunchao Zhang Fdnz++; 5790e3ece09SJunchao Zhang } else { 5800e3ece09SJunchao Zhang Foj[Fonz] = curColIdx; // in global 5810e3ece09SJunchao Zhang Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 5820e3ece09SJunchao Zhang for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 5830e3ece09SJunchao Zhang FonzDups += dups; 5840e3ece09SJunchao Zhang Fonz++; 5850e3ece09SJunchao Zhang } 5860e3ece09SJunchao Zhang cur += dups; 5870e3ece09SJunchao Zhang FnzDups += dups; 5880e3ece09SJunchao Zhang } 5890e3ece09SJunchao Zhang iter += dupRows; // Move to next unique row 5900e3ece09SJunchao Zhang } 5910e3ece09SJunchao Zhang PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 5920e3ece09SJunchao Zhang PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs)); 5930e3ece09SJunchao Zhang 5940e3ece09SJunchao Zhang // Combine global column indices in garray1[] and Foj[] 5950e3ece09SJunchao Zhang PetscInt n2, *garray2; 5960e3ece09SJunchao Zhang 
5970e3ece09SJunchao Zhang PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 5980e3ece09SJunchao Zhang mm->sf = reduceSF; 5997b8d4ba6SJunchao Zhang mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 6007b8d4ba6SJunchao Zhang mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 601aaa8cc7dSPierre Jolivet mm->garray = garray2; // give ownership, so no free 6020e3ece09SJunchao Zhang mm->n = n2; 6030e3ece09SJunchao Zhang mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 6040e3ece09SJunchao Zhang mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 6050e3ece09SJunchao Zhang mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 6060e3ece09SJunchao Zhang mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 6070e3ece09SJunchao Zhang mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 6080e3ece09SJunchao Zhang 6090e3ece09SJunchao Zhang // Output Fd and Fo in KokkosCsrMatrix format 6107b8d4ba6SJunchao Zhang MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 6110e3ece09SJunchao Zhang MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 6120e3ece09SJunchao Zhang MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 6137b8d4ba6SJunchao Zhang MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 6140e3ece09SJunchao Zhang MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 6150e3ece09SJunchao Zhang MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 6160e3ece09SJunchao Zhang 6170e3ece09SJunchao Zhang PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 6180e3ece09SJunchao Zhang PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size 
is n2, length of garray2[] 6190e3ece09SJunchao Zhang 6200e3ece09SJunchao Zhang // Compute kernel launch parameters in merging E 6210e3ece09SJunchao Zhang PetscInt teamSize, vectorLength, rowsPerTeam; 6220e3ece09SJunchao Zhang 6230e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 6240e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 6250e3ece09SJunchao Zhang mm->E_TeamSize = teamSize; 6260e3ece09SJunchao Zhang mm->E_VectorLength = vectorLength; 6270e3ece09SJunchao Zhang mm->E_RowsPerTeam = rowsPerTeam; 6280e3ece09SJunchao Zhang } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 6290e3ece09SJunchao Zhang 6300e3ece09SJunchao Zhang // Handy aliases 6310e3ece09SJunchao Zhang auto &Aa = A.values; 6320e3ece09SJunchao Zhang auto &Ba = B.values; 6330e3ece09SJunchao Zhang const auto &Ai = A.graph.row_map; 6340e3ece09SJunchao Zhang const auto &Bi = B.graph.row_map; 6350e3ece09SJunchao Zhang const auto &E_NzLeft = mm->E_NzLeft; 6360e3ece09SJunchao Zhang auto &leafBuf = mm->leafBuf; 6370e3ece09SJunchao Zhang auto &rootBuf = mm->rootBuf; 6380e3ece09SJunchao Zhang PetscSF reduceSF = mm->sf; 6390e3ece09SJunchao Zhang PetscInt Em = A.numRows(); 6400e3ece09SJunchao Zhang PetscInt teamSize = mm->E_TeamSize; 6410e3ece09SJunchao Zhang PetscInt vectorLength = mm->E_VectorLength; 6420e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->E_RowsPerTeam; 6430e3ece09SJunchao Zhang PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 6440e3ece09SJunchao Zhang 6450e3ece09SJunchao Zhang // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 6460e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 647d326c3f1SJunchao Zhang Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 6480e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, 
rowsPerTeam), [&](PetscInt k) { 6490e3ece09SJunchao Zhang PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 6500e3ece09SJunchao Zhang if (i < Em) { 6510e3ece09SJunchao Zhang PetscInt disp = Ai(i) + Bi(i); 6520e3ece09SJunchao Zhang PetscInt alen = Ai(i + 1) - Ai(i); 6530e3ece09SJunchao Zhang PetscInt blen = Bi(i + 1) - Bi(i); 6540e3ece09SJunchao Zhang PetscInt nzleft = E_NzLeft(i); 6550e3ece09SJunchao Zhang 6560e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 6570e3ece09SJunchao Zhang MatScalar &val = leafBuf(disp + j); 6580e3ece09SJunchao Zhang if (j < nzleft) { // B left 6590e3ece09SJunchao Zhang val = Ba(Bi(i) + j); 6600e3ece09SJunchao Zhang } else if (j < nzleft + alen) { // diag A 6610e3ece09SJunchao Zhang val = Aa(Ai(i) + j - nzleft); 6620e3ece09SJunchao Zhang } else { // B right 6630e3ece09SJunchao Zhang val = Ba(Bi(i) + j - alen); 664f0e6e2d1SJunchao Zhang } 665f0e6e2d1SJunchao Zhang }); 666f0e6e2d1SJunchao Zhang } 667f0e6e2d1SJunchao Zhang }); 6680e3ece09SJunchao Zhang })); 6690e3ece09SJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 670f0e6e2d1SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 671f0e6e2d1SJunchao Zhang } 6720e3ece09SJunchao Zhang 673aaa8cc7dSPierre Jolivet // To finish MatMPIAIJKokkosReduce. 
6740e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 6750e3ece09SJunchao Zhang { 6760e3ece09SJunchao Zhang auto &leafBuf = mm->leafBuf; 6770e3ece09SJunchao Zhang auto &rootBuf = mm->rootBuf; 6780e3ece09SJunchao Zhang auto &Fda = mm->Fd.values; 6790e3ece09SJunchao Zhang const auto &Fdjmap = mm->Fdjmap; 6800e3ece09SJunchao Zhang const auto &Fdjperm = mm->Fdjperm; 6810e3ece09SJunchao Zhang auto Fdnz = mm->Fd.nnz(); 6820e3ece09SJunchao Zhang auto &Foa = mm->Fo.values; 6830e3ece09SJunchao Zhang const auto &Fojmap = mm->Fojmap; 6840e3ece09SJunchao Zhang const auto &Fojperm = mm->Fojperm; 6850e3ece09SJunchao Zhang auto Fonz = mm->Fo.nnz(); 6860e3ece09SJunchao Zhang PetscSF reduceSF = mm->sf; 6870e3ece09SJunchao Zhang 688d326c3f1SJunchao Zhang PetscFunctionBegin; 6890e3ece09SJunchao Zhang PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 6900e3ece09SJunchao Zhang 6910e3ece09SJunchao Zhang // Reduce data in rootBuf to Fd and Fo 6920e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 693d326c3f1SJunchao Zhang Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 6940e3ece09SJunchao Zhang PetscScalar sum = 0.0; 6950e3ece09SJunchao Zhang for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 6960e3ece09SJunchao Zhang Fda(i) = sum; 6970e3ece09SJunchao Zhang })); 6980e3ece09SJunchao Zhang 6990e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 700d326c3f1SJunchao Zhang Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 7010e3ece09SJunchao Zhang PetscScalar sum = 0.0; 7020e3ece09SJunchao Zhang for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 7030e3ece09SJunchao 
Zhang Foa(i) = sum; 7040e3ece09SJunchao Zhang })); 7050e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 7060e3ece09SJunchao Zhang } 7070e3ece09SJunchao Zhang 7080e3ece09SJunchao Zhang /* 7090e3ece09SJunchao Zhang MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 7100e3ece09SJunchao Zhang 7110e3ece09SJunchao Zhang This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 7120e3ece09SJunchao Zhang device and involves various index mapping. 7130e3ece09SJunchao Zhang 7140e3ece09SJunchao Zhang In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 7150e3ece09SJunchao Zhang Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 7160e3ece09SJunchao Zhang to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 7170e3ece09SJunchao Zhang F has the same column layout as E. 7180e3ece09SJunchao Zhang 7190e3ece09SJunchao Zhang Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 720aaa8cc7dSPierre Jolivet Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 7210e3ece09SJunchao Zhang Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 7220e3ece09SJunchao Zhang column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 7230e3ece09SJunchao Zhang column indices in Fo and update Fo with local indices. 
7240e3ece09SJunchao Zhang 7250e3ece09SJunchao Zhang Input Parameters: 7260e3ece09SJunchao Zhang + E - the MPIAIJKOKKOS matrix 7279c89aa79SPierre Jolivet . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 7280e3ece09SJunchao Zhang . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 7290e3ece09SJunchao Zhang - mm - to stash matproduct intermediate data structures 7300e3ece09SJunchao Zhang 7310e3ece09SJunchao Zhang Output Parameters: 7320e3ece09SJunchao Zhang + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 7330e3ece09SJunchao Zhang - mm - contains various info, such as garray2[], Fd, Fo, etc. 7340e3ece09SJunchao Zhang 7350e3ece09SJunchao Zhang Notes: 7360e3ece09SJunchao Zhang When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 7370e3ece09SJunchao Zhang The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities. 7380e3ece09SJunchao Zhang */ 7390e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 7400e3ece09SJunchao Zhang { 7410e3ece09SJunchao Zhang Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 7420e3ece09SJunchao Zhang Mat A = empi->A, B = empi->B; // diag and off-diag 7430e3ece09SJunchao Zhang Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 7440e3ece09SJunchao Zhang PetscInt Em = E->rmap->n; // #local rows 7450e3ece09SJunchao Zhang MPI_Comm comm; 7460e3ece09SJunchao Zhang 7470e3ece09SJunchao Zhang PetscFunctionBegin; 7480e3ece09SJunchao Zhang PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 7490e3ece09SJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 7500e3ece09SJunchao Zhang Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 7510e3ece09SJunchao Zhang PetscInt n1 = B->cmap->n, *Ai = 
aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 7520e3ece09SJunchao Zhang const PetscInt *garray1 = empi->garray; // its size is n1 7530e3ece09SJunchao Zhang PetscInt cstart, cend; 7540e3ece09SJunchao Zhang PetscSF bcastSF; 7550e3ece09SJunchao Zhang 7560e3ece09SJunchao Zhang PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 7570e3ece09SJunchao Zhang 7580e3ece09SJunchao Zhang // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 7597b8d4ba6SJunchao Zhang PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 7600e3ece09SJunchao Zhang PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 7610e3ece09SJunchao Zhang for (PetscInt i = 0; i < Em; i++) { 7620e3ece09SJunchao Zhang const PetscInt *first, *last, *it; 7630e3ece09SJunchao Zhang PetscInt count, step; 7640e3ece09SJunchao Zhang // std::lower_bound(first,last,cstart), but need to use global column indices 7650e3ece09SJunchao Zhang first = Bj + Bi[i]; 7660e3ece09SJunchao Zhang last = Bj + Bi[i + 1]; 7670e3ece09SJunchao Zhang count = last - first; 7680e3ece09SJunchao Zhang while (count > 0) { 7690e3ece09SJunchao Zhang it = first; 7700e3ece09SJunchao Zhang step = count / 2; 7710e3ece09SJunchao Zhang it += step; 7720e3ece09SJunchao Zhang if (empi->garray[*it] < cstart) { // map local to global 7730e3ece09SJunchao Zhang first = ++it; 7740e3ece09SJunchao Zhang count -= step + 1; 7750e3ece09SJunchao Zhang } else count = step; 7760e3ece09SJunchao Zhang } 7770e3ece09SJunchao Zhang E_NzLeft[i] = first - (Bj + Bi[i]); 7780e3ece09SJunchao Zhang E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 7790e3ece09SJunchao Zhang } 7800e3ece09SJunchao Zhang 7810e3ece09SJunchao Zhang // Compute row pointer Fi of F 7820e3ece09SJunchao Zhang PetscInt *Fi, Fm, Fnz; 7830e3ece09SJunchao Zhang PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 7840e3ece09SJunchao Zhang 
PetscCall(PetscMalloc1(Fm + 1, &Fi)); 7850e3ece09SJunchao Zhang Fi[0] = 0; 7860e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 7870e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 7880e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 7890e3ece09SJunchao Zhang Fnz = Fi[Fm]; 7900e3ece09SJunchao Zhang 7910e3ece09SJunchao Zhang // Build the real PetscSF for bcasting E rows (buffer to buffer) 7920e3ece09SJunchao Zhang const PetscMPIInt *iranks, *ranks; 7930e3ece09SJunchao Zhang const PetscInt *ioffset, *irootloc, *roffset; 794c09cee04SJames Wright PetscMPIInt niranks, nranks; 795c09cee04SJames Wright PetscInt *sdisp, *rdisp; 7960e3ece09SJunchao Zhang MPI_Request *reqs; 7970e3ece09SJunchao Zhang PetscMPIInt tag; 7980e3ece09SJunchao Zhang 7990e3ece09SJunchao Zhang PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 8000e3ece09SJunchao Zhang PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 8010e3ece09SJunchao Zhang PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 8020e3ece09SJunchao Zhang 8030e3ece09SJunchao Zhang sdisp[0] = 0; // send displacement 8040e3ece09SJunchao Zhang for (PetscInt i = 0; i < niranks; i++) { 8050e3ece09SJunchao Zhang sdisp[i + 1] = sdisp[i]; 8060e3ece09SJunchao Zhang for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 8070e3ece09SJunchao Zhang PetscInt r = irootloc[j]; // row to be sent 8080e3ece09SJunchao Zhang sdisp[i + 1] += E_RowLen[r]; 8090e3ece09SJunchao Zhang } 8100e3ece09SJunchao Zhang } 8110e3ece09SJunchao Zhang 8120e3ece09SJunchao Zhang PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 8136497c311SBarry Smith for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, 
ranks[i], tag, comm, &reqs[i])); 8146497c311SBarry Smith for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 8150e3ece09SJunchao Zhang PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 8160e3ece09SJunchao Zhang 8170e3ece09SJunchao Zhang PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 8180e3ece09SJunchao Zhang PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 8190e3ece09SJunchao Zhang PetscSFNode *iremote; // give ownership to bcastSF 8200e3ece09SJunchao Zhang PetscCall(PetscMalloc1(nleaves, &iremote)); 8210e3ece09SJunchao Zhang for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 8220e3ece09SJunchao Zhang PetscInt k = 0; 8230e3ece09SJunchao Zhang for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i] 8240e3ece09SJunchao Zhang iremote[j].rank = ranks[i]; 8250e3ece09SJunchao Zhang iremote[j].index = rdisp[i] + k; // their root location 8260e3ece09SJunchao Zhang k++; 8270e3ece09SJunchao Zhang } 8280e3ece09SJunchao Zhang } 8290e3ece09SJunchao Zhang PetscCall(PetscSFCreate(comm, &bcastSF)); 8300e3ece09SJunchao Zhang PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 8310e3ece09SJunchao Zhang PetscCall(PetscFree3(sdisp, rdisp, reqs)); 8320e3ece09SJunchao Zhang 8330e3ece09SJunchao Zhang // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 8347b8d4ba6SJunchao Zhang PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 8350e3ece09SJunchao Zhang PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 8360e3ece09SJunchao Zhang rowoffset[0] = 0; 837ac530a7eSPierre Jolivet for (PetscInt i = 0; i < ioffset[niranks]; i++) rowoffset[i + 1] = rowoffset[i] + 
E_RowLen[irootloc[i]]; 8380e3ece09SJunchao Zhang 8390e3ece09SJunchao Zhang // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 8400e3ece09SJunchao Zhang PetscInt *jbuf, *Fj; 8410e3ece09SJunchao Zhang PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 8420e3ece09SJunchao Zhang for (PetscInt k = 0; k < ioffset[niranks]; k++) { 8430e3ece09SJunchao Zhang PetscInt i = irootloc[k]; // row to be copied 8440e3ece09SJunchao Zhang PetscInt *buf = &jbuf[rowoffset[k]]; 8450e3ece09SJunchao Zhang PetscInt nzLeft = E_NzLeft[i]; 8460e3ece09SJunchao Zhang PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 8470e3ece09SJunchao Zhang for (PetscInt j = 0; j < alen + blen; j++) { 8480e3ece09SJunchao Zhang if (j < nzLeft) { 8490e3ece09SJunchao Zhang buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 8500e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { 8510e3ece09SJunchao Zhang buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 8520e3ece09SJunchao Zhang } else { 8530e3ece09SJunchao Zhang buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 8540e3ece09SJunchao Zhang } 8550e3ece09SJunchao Zhang } 8560e3ece09SJunchao Zhang } 8570e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 8580e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 8590e3ece09SJunchao Zhang 8600e3ece09SJunchao Zhang // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 8617b8d4ba6SJunchao Zhang MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 8627b8d4ba6SJunchao Zhang MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 
8630e3ece09SJunchao Zhang MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 8640e3ece09SJunchao Zhang MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 8650e3ece09SJunchao Zhang 8660e3ece09SJunchao Zhang Fdi[0] = Foi[0] = 0; 8670e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 8680e3ece09SJunchao Zhang PetscInt *first, *last, *lb1, *lb2; 8690e3ece09SJunchao Zhang // cut the row into: Left, [cstart, cend), Right 8700e3ece09SJunchao Zhang first = Fj + Fi[i]; 8710e3ece09SJunchao Zhang last = Fj + Fi[i + 1]; 8720e3ece09SJunchao Zhang lb1 = std::lower_bound(first, last, cstart); 8730e3ece09SJunchao Zhang F_NzLeft[i] = lb1 - first; 8740e3ece09SJunchao Zhang lb2 = std::lower_bound(first, last, cend); 8750e3ece09SJunchao Zhang Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 8760e3ece09SJunchao Zhang Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 8770e3ece09SJunchao Zhang } 8780e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 8790e3ece09SJunchao Zhang Fdi[i + 1] += Fdi[i]; 8800e3ece09SJunchao Zhang Foi[i + 1] += Foi[i]; 8810e3ece09SJunchao Zhang } 8820e3ece09SJunchao Zhang 8830e3ece09SJunchao Zhang // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 
8840e3ece09SJunchao Zhang PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 8857b8d4ba6SJunchao Zhang MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 8860e3ece09SJunchao Zhang MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 8870e3ece09SJunchao Zhang 8880e3ece09SJunchao Zhang for (PetscInt i = 0; i < Fm; i++) { 8890e3ece09SJunchao Zhang PetscInt nzLeft = F_NzLeft[i]; 8900e3ece09SJunchao Zhang PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 8910e3ece09SJunchao Zhang for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 8920e3ece09SJunchao Zhang gid = Fj[Fi[i] + j]; 8930e3ece09SJunchao Zhang if (j < nzLeft) { // left, in global 8940e3ece09SJunchao Zhang Foj[Foi[i] + j] = gid; 8950e3ece09SJunchao Zhang } else if (j < nzLeft + len) { // diag, in local 8960e3ece09SJunchao Zhang Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 8970e3ece09SJunchao Zhang } else { // right, in global 8980e3ece09SJunchao Zhang Foj[Foi[i] + j - len] = gid; 8990e3ece09SJunchao Zhang } 9000e3ece09SJunchao Zhang } 9010e3ece09SJunchao Zhang } 9020e3ece09SJunchao Zhang PetscCall(PetscFree2(jbuf, Fj)); 9030e3ece09SJunchao Zhang PetscCall(PetscFree(Fi)); 9040e3ece09SJunchao Zhang 9050e3ece09SJunchao Zhang // Reduce global indices in Foj[] and garray1[] into local ones 9060e3ece09SJunchao Zhang PetscInt n2, *garray2; 9070e3ece09SJunchao Zhang PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 9080e3ece09SJunchao Zhang 9090e3ece09SJunchao Zhang // Record the plans built above, for reuse 9100e3ece09SJunchao Zhang PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 9117b8d4ba6SJunchao Zhang PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 9120e3ece09SJunchao Zhang Kokkos::deep_copy(irootloc_h, tmp); 9130e3ece09SJunchao Zhang mm->sf = bcastSF; 9140e3ece09SJunchao Zhang mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 9150e3ece09SJunchao Zhang mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 9160e3ece09SJunchao Zhang mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 9170e3ece09SJunchao Zhang mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 9187b8d4ba6SJunchao Zhang mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 9197b8d4ba6SJunchao Zhang mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 9200e3ece09SJunchao Zhang mm->garray = garray2; 9210e3ece09SJunchao Zhang mm->n = n2; 9220e3ece09SJunchao Zhang 9230e3ece09SJunchao Zhang // Output Fd and Fo in KokkosCsrMatrix format 9247b8d4ba6SJunchao Zhang MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 9250e3ece09SJunchao Zhang MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 9260e3ece09SJunchao Zhang MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 9270e3ece09SJunchao Zhang MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 9280e3ece09SJunchao Zhang MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 9290e3ece09SJunchao Zhang 9300e3ece09SJunchao Zhang PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 9310e3ece09SJunchao Zhang PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 9320e3ece09SJunchao Zhang 9330e3ece09SJunchao Zhang // Compute kernel launch parameters in merging E or splitting F 
9340e3ece09SJunchao Zhang PetscInt teamSize, vectorLength, rowsPerTeam; 9350e3ece09SJunchao Zhang 9360e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 9370e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 9380e3ece09SJunchao Zhang mm->E_TeamSize = teamSize; 9390e3ece09SJunchao Zhang mm->E_VectorLength = vectorLength; 9400e3ece09SJunchao Zhang mm->E_RowsPerTeam = rowsPerTeam; 9410e3ece09SJunchao Zhang 9420e3ece09SJunchao Zhang teamSize = vectorLength = rowsPerTeam = -1; 9430e3ece09SJunchao Zhang PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 9440e3ece09SJunchao Zhang mm->F_TeamSize = teamSize; 9450e3ece09SJunchao Zhang mm->F_VectorLength = vectorLength; 9460e3ece09SJunchao Zhang mm->F_RowsPerTeam = rowsPerTeam; 9470e3ece09SJunchao Zhang } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 9480e3ece09SJunchao Zhang 9490e3ece09SJunchao Zhang // Sync E's value to device 950f3d3cd90SZach Atkins PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); 951f3d3cd90SZach Atkins PetscCall(KokkosDualViewSyncDevice(bkok->a_dual, PetscGetKokkosExecutionSpace())); 9520e3ece09SJunchao Zhang 9530e3ece09SJunchao Zhang // Handy aliases 9540e3ece09SJunchao Zhang const auto &Aa = akok->a_dual.view_device(); 9550e3ece09SJunchao Zhang const auto &Ba = bkok->a_dual.view_device(); 9560e3ece09SJunchao Zhang const auto &Ai = akok->i_dual.view_device(); 9570e3ece09SJunchao Zhang const auto &Bi = bkok->i_dual.view_device(); 9580e3ece09SJunchao Zhang 9590e3ece09SJunchao Zhang // Fetch the plans 9600e3ece09SJunchao Zhang PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 9610e3ece09SJunchao Zhang PetscSF &bcastSF = mm->sf; 9620e3ece09SJunchao Zhang MatScalarKokkosView &rootBuf = mm->rootBuf; 9630e3ece09SJunchao Zhang 
MatScalarKokkosView &leafBuf = mm->leafBuf; 9640e3ece09SJunchao Zhang PetscIntKokkosView &irootloc = mm->irootloc; 9650e3ece09SJunchao Zhang PetscIntKokkosView &rowoffset = mm->rowoffset; 9660e3ece09SJunchao Zhang 9670e3ece09SJunchao Zhang PetscInt teamSize = mm->E_TeamSize; 9680e3ece09SJunchao Zhang PetscInt vectorLength = mm->E_VectorLength; 9690e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->E_RowsPerTeam; 9700e3ece09SJunchao Zhang PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 9710e3ece09SJunchao Zhang 9720e3ece09SJunchao Zhang // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 9730e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 974d326c3f1SJunchao Zhang Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 9750e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 9760e3ece09SJunchao Zhang size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 9770e3ece09SJunchao Zhang if (r < irootloc.extent(0)) { 9780e3ece09SJunchao Zhang PetscInt i = irootloc(r); // row i of E 9790e3ece09SJunchao Zhang PetscInt disp = rowoffset(r); 9800e3ece09SJunchao Zhang PetscInt alen = Ai(i + 1) - Ai(i); 9810e3ece09SJunchao Zhang PetscInt blen = Bi(i + 1) - Bi(i); 9820e3ece09SJunchao Zhang PetscInt nzleft = E_NzLeft(i); 9830e3ece09SJunchao Zhang 9840e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 9850e3ece09SJunchao Zhang if (j < nzleft) { // B left 9860e3ece09SJunchao Zhang rootBuf(disp + j) = Ba(Bi(i) + j); 9870e3ece09SJunchao Zhang } else if (j < nzleft + alen) { // diag A 9880e3ece09SJunchao Zhang rootBuf(disp + j) = Aa(Ai(i) + j - nzleft); 9890e3ece09SJunchao Zhang } else { // B right 9900e3ece09SJunchao Zhang rootBuf(disp + j) = Ba(Bi(i) + j - alen); 9910e3ece09SJunchao Zhang } 9920e3ece09SJunchao Zhang }); 9930e3ece09SJunchao 
Zhang } 9940e3ece09SJunchao Zhang }); 9950e3ece09SJunchao Zhang })); 9960e3ece09SJunchao Zhang PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 9970e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 9980e3ece09SJunchao Zhang } 9990e3ece09SJunchao Zhang 10000e3ece09SJunchao Zhang // To finish MatMPIAIJKokkosBcast. 10010e3ece09SJunchao Zhang static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 10020e3ece09SJunchao Zhang { 10030e3ece09SJunchao Zhang PetscFunctionBegin; 10040e3ece09SJunchao Zhang const auto &Fd = mm->Fd; 10050e3ece09SJunchao Zhang const auto &Fo = mm->Fo; 10060e3ece09SJunchao Zhang const auto &Fdi = Fd.graph.row_map; 10070e3ece09SJunchao Zhang const auto &Foi = Fo.graph.row_map; 10080e3ece09SJunchao Zhang auto &Fda = Fd.values; 10090e3ece09SJunchao Zhang auto &Foa = Fo.values; 10100e3ece09SJunchao Zhang auto Fm = Fd.numRows(); 10110e3ece09SJunchao Zhang 10120e3ece09SJunchao Zhang PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 10130e3ece09SJunchao Zhang PetscSF &bcastSF = mm->sf; 10140e3ece09SJunchao Zhang MatScalarKokkosView &rootBuf = mm->rootBuf; 10150e3ece09SJunchao Zhang MatScalarKokkosView &leafBuf = mm->leafBuf; 10160e3ece09SJunchao Zhang PetscInt teamSize = mm->F_TeamSize; 10170e3ece09SJunchao Zhang PetscInt vectorLength = mm->F_VectorLength; 10180e3ece09SJunchao Zhang PetscInt rowsPerTeam = mm->F_RowsPerTeam; 10190e3ece09SJunchao Zhang PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 10200e3ece09SJunchao Zhang 10210e3ece09SJunchao Zhang PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 10220e3ece09SJunchao Zhang 10230e3ece09SJunchao Zhang // Update Fda and Foa with new data in leafBuf (as if it is Fa) 10240e3ece09SJunchao Zhang PetscCallCXX(Kokkos::parallel_for( 1025d326c3f1SJunchao Zhang 
Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 10260e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 10270e3ece09SJunchao Zhang PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 10280e3ece09SJunchao Zhang if (i < Fm) { 10290e3ece09SJunchao Zhang PetscInt nzLeft = F_NzLeft(i); 10300e3ece09SJunchao Zhang PetscInt alen = Fdi(i + 1) - Fdi(i); 10310e3ece09SJunchao Zhang PetscInt blen = Foi(i + 1) - Foi(i); 10320e3ece09SJunchao Zhang PetscInt Fii = Fdi(i) + Foi(i); 10330e3ece09SJunchao Zhang 10340e3ece09SJunchao Zhang Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 10350e3ece09SJunchao Zhang PetscScalar val = leafBuf(Fii + j); 10360e3ece09SJunchao Zhang if (j < nzLeft) { // left 10370e3ece09SJunchao Zhang Foa(Foi(i) + j) = val; 10380e3ece09SJunchao Zhang } else if (j < nzLeft + alen) { // diag 10390e3ece09SJunchao Zhang Fda(Fdi(i) + j - nzLeft) = val; 10400e3ece09SJunchao Zhang } else { // right 10410e3ece09SJunchao Zhang Foa(Foi(i) + j - alen) = val; 10420e3ece09SJunchao Zhang } 10430e3ece09SJunchao Zhang }); 10440e3ece09SJunchao Zhang } 10450e3ece09SJunchao Zhang }); 10460e3ece09SJunchao Zhang })); 10470e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 10480e3ece09SJunchao Zhang } 10490e3ece09SJunchao Zhang 10500e3ece09SJunchao Zhang static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 10510e3ece09SJunchao Zhang { 10520e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 10530e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 10540e3ece09SJunchao Zhang KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 10550e3ece09SJunchao Zhang PetscInt cstart, cend; 10560e3ece09SJunchao Zhang MPI_Comm comm; 10570e3ece09SJunchao Zhang 10580e3ece09SJunchao Zhang PetscFunctionBegin; 
10590e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 10600e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 10610e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 10620e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 10630e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 10640e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 10650e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 10660e3ece09SJunchao Zhang 10670e3ece09SJunchao Zhang // TODO: add command line options to select spgemm algorithms 10680e3ece09SJunchao Zhang auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 10690e3ece09SJunchao Zhang 10700e3ece09SJunchao Zhang // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 10710e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 10720e3ece09SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 10730e3ece09SJunchao Zhang spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1074f0e6e2d1SJunchao Zhang #endif 10750e3ece09SJunchao Zhang #endif 10760e3ece09SJunchao Zhang 10770e3ece09SJunchao Zhang PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 10780e3ece09SJunchao Zhang PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 10790e3ece09SJunchao Zhang PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 10800e3ece09SJunchao Zhang PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 10810e3ece09SJunchao Zhang 10820e3ece09SJunchao Zhang // Aot * (B's diag + B's off-diag) 10830e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 10840e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 10850e3ece09SJunchao 
Zhang // KK spgemm_symbolic() only populates the result's row map, but not its columns. 10860e3ece09SJunchao Zhang // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 10870e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 10880e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 10890e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1090d326c3f1SJunchao Zhang 10910e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C3)); 10920e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C4)); 10930e3ece09SJunchao Zhang #endif 10940e3ece09SJunchao Zhang 10950e3ece09SJunchao Zhang // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 10967b8d4ba6SJunchao Zhang PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 10970e3ece09SJunchao Zhang PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 10980e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 10990e3ece09SJunchao Zhang 11000e3ece09SJunchao Zhang // Adt * (B's diag + B's off-diag) 11010e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 11020e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11030e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 11040e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11050e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 11060e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C1)); 11070e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 11080e3ece09SJunchao Zhang #endif 11090e3ece09SJunchao Zhang 11100e3ece09SJunchao Zhang 
PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 11110e3ece09SJunchao Zhang 11120e3ece09SJunchao Zhang // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 11137b8d4ba6SJunchao Zhang MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 11140e3ece09SJunchao Zhang PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1115d326c3f1SJunchao Zhang PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 11160e3ece09SJunchao Zhang PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 11170e3ece09SJunchao Zhang 11180e3ece09SJunchao Zhang // C = (C1+Fd, C2+Fo) 11190e3ece09SJunchao Zhang PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 11200e3ece09SJunchao Zhang PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 11210e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 11220e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 11230e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 11240e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 11250e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 11260e3ece09SJunchao Zhang } 11270e3ece09SJunchao Zhang 11280e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 11290e3ece09SJunchao Zhang { 11300e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ 
*>(A->data); 11310e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 11320e3ece09SJunchao Zhang KokkosCsrMatrix Adt, Aot, Bd, Bo; 11330e3ece09SJunchao Zhang MPI_Comm comm; 11340e3ece09SJunchao Zhang 11350e3ece09SJunchao Zhang PetscFunctionBegin; 11360e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 11370e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 11380e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 11390e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 11400e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 11410e3ece09SJunchao Zhang 11420e3ece09SJunchao Zhang // Aot * (B's diag + B's off-diag) 11430e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 11440e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 11450e3ece09SJunchao Zhang 11460e3ece09SJunchao Zhang // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 11470e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 11480e3ece09SJunchao Zhang 11490e3ece09SJunchao Zhang // Adt * (B's diag + B's off-diag) 11500e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 11510e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 11520e3ece09SJunchao Zhang 11530e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 11540e3ece09SJunchao Zhang 11550e3ece09SJunchao Zhang // C = (C1+Fd, C2+Fo) 11560e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 11570e3ece09SJunchao Zhang 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 11580e3ece09SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 11590e3ece09SJunchao Zhang } 1160f0e6e2d1SJunchao Zhang 1161076ba34aSJunchao Zhang /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1162076ba34aSJunchao Zhang 1163076ba34aSJunchao Zhang Input Parameters: 1164076ba34aSJunchao Zhang + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1165076ba34aSJunchao Zhang . A - an MPIAIJKOKKOS matrix 1166076ba34aSJunchao Zhang . B - an MPIAIJKOKKOS matrix 1167076ba34aSJunchao Zhang - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 1168076ba34aSJunchao Zhang */ 1169d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1170d71ae5a4SJacob Faibussowitsch { 11710e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 11720e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 11730e3ece09SJunchao Zhang KokkosCsrMatrix Ad, Ao, Bd, Bo; 1174076ba34aSJunchao Zhang 1175076ba34aSJunchao Zhang PetscFunctionBegin; 11760e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 11770e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 11780e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 11790e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 11800e3ece09SJunchao Zhang 11810e3ece09SJunchao Zhang // TODO: add command line options to select spgemm algorithms 11820e3ece09SJunchao Zhang auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 11830e3ece09SJunchao Zhang 11840e3ece09SJunchao Zhang // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 11850e3ece09SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 11860e3ece09SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 11870e3ece09SJunchao Zhang spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 11880e3ece09SJunchao Zhang #endif 1189f0e6e2d1SJunchao Zhang #endif 1190f0e6e2d1SJunchao Zhang 11910e3ece09SJunchao Zhang mm->kh1.create_spgemm_handle(spgemm_alg); 11920e3ece09SJunchao Zhang mm->kh2.create_spgemm_handle(spgemm_alg); 11930e3ece09SJunchao Zhang mm->kh3.create_spgemm_handle(spgemm_alg); 11940e3ece09SJunchao Zhang mm->kh4.create_spgemm_handle(spgemm_alg); 1195076ba34aSJunchao Zhang 11960e3ece09SJunchao Zhang // Bcast B's rows to form F, and overlap the communication 11977b8d4ba6SJunchao Zhang PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 11980e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1199076ba34aSJunchao Zhang 12000e3ece09SJunchao Zhang // A's diag * (B's diag + B's off-diag) 12010e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 12020e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 12030e3ece09SJunchao Zhang // KK spgemm_symbolic() only populates the result's row map, but not its columns. 12040e3ece09SJunchao Zhang // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
12050e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 12060e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 12070e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 12080e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C1)); 12090e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 12100e3ece09SJunchao Zhang #endif 1211076ba34aSJunchao Zhang 12120e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1213076ba34aSJunchao Zhang 12140e3ece09SJunchao Zhang // A's off-diag * (F's diag + F's off-diag) 12150e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12160e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12170e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12180e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12190e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 12200e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C3)); 12210e3ece09SJunchao Zhang PetscCallCXX(sort_crs_matrix(mm->C4)); 12220e3ece09SJunchao Zhang #endif 1223076ba34aSJunchao Zhang 12240e3ece09SJunchao Zhang // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 12257b8d4ba6SJunchao Zhang MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 12260e3ece09SJunchao Zhang PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1227d326c3f1SJunchao Zhang PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = 
map(oldj(i)); })); 12280e3ece09SJunchao Zhang mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 12290e3ece09SJunchao Zhang 12300e3ece09SJunchao Zhang // C = (Cd, Co) = (C1+C3, C2+C4) 12310e3ece09SJunchao Zhang mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 12320e3ece09SJunchao Zhang mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 12330e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 12340e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 12350e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 12360e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 12373ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1238076ba34aSJunchao Zhang } 1239076ba34aSJunchao Zhang 12400e3ece09SJunchao Zhang static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1241d71ae5a4SJacob Faibussowitsch { 12420e3ece09SJunchao Zhang Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 12430e3ece09SJunchao Zhang Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 12440e3ece09SJunchao Zhang KokkosCsrMatrix Ad, Ao, Bd, Bo; 1245076ba34aSJunchao Zhang 1246076ba34aSJunchao Zhang PetscFunctionBegin; 12470e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 12480e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 12490e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 12500e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1251076ba34aSJunchao Zhang 12520e3ece09SJunchao Zhang // Bcast B's rows to form F, and overlap the communication 12530e3ece09SJunchao Zhang 
PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1254076ba34aSJunchao Zhang 12550e3ece09SJunchao Zhang // A's diag * (B's diag + B's off-diag) 12560e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 12570e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1258076ba34aSJunchao Zhang 12590e3ece09SJunchao Zhang PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1260076ba34aSJunchao Zhang 12610e3ece09SJunchao Zhang // A's off-diag * (F's diag + F's off-diag) 12620e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 12630e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 12640e3ece09SJunchao Zhang 12650e3ece09SJunchao Zhang // C = (Cd, Co) = (C1+C3, C2+C4) 12660e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 12670e3ece09SJunchao Zhang PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 12683ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1269076ba34aSJunchao Zhang } 1270076ba34aSJunchao Zhang 127166976f2fSJacob Faibussowitsch static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1272d71ae5a4SJacob Faibussowitsch { 12730e3ece09SJunchao Zhang Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 12740e3ece09SJunchao Zhang Mat_Product *product; 1275cc1eb50dSBarry Smith MatProductCtx_MPIAIJKokkos *pdata; 1276076ba34aSJunchao Zhang MatProductType ptype; 12770e3ece09SJunchao Zhang Mat A, B; 1278076ba34aSJunchao Zhang 1279076ba34aSJunchao Zhang PetscFunctionBegin; 12800e3ece09SJunchao Zhang MatCheckProduct(C, 1); // make sure C is a product 12810e3ece09SJunchao Zhang product = C->product; 1282cc1eb50dSBarry Smith pdata = static_cast<MatProductCtx_MPIAIJKokkos *>(product->data); 
1283076ba34aSJunchao Zhang ptype = product->type; 1284076ba34aSJunchao Zhang A = product->A; 1285076ba34aSJunchao Zhang B = product->B; 1286076ba34aSJunchao Zhang 12870e3ece09SJunchao Zhang // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 12880e3ece09SJunchao Zhang // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 12890e3ece09SJunchao Zhang // we still do numeric. 12900e3ece09SJunchao Zhang if (pdata->reusesym) { // numeric reuses results from symbolic 12910e3ece09SJunchao Zhang pdata->reusesym = PETSC_FALSE; 12923ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1293076ba34aSJunchao Zhang } 1294076ba34aSJunchao Zhang 1295076ba34aSJunchao Zhang if (ptype == MATPRODUCT_AB) { 12960e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1297076ba34aSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 12980e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 12990e3ece09SJunchao Zhang } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 13000e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 13010e3ece09SJunchao Zhang PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1302076ba34aSJunchao Zhang } 13030e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 13040e3ece09SJunchao Zhang PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 13053ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1306076ba34aSJunchao Zhang } 1307076ba34aSJunchao Zhang 130866976f2fSJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1309d71ae5a4SJacob Faibussowitsch { 1310076ba34aSJunchao Zhang Mat A, B; 13110e3ece09SJunchao Zhang Mat_Product *product; 1312076ba34aSJunchao 
Zhang MatProductType ptype; 1313cc1eb50dSBarry Smith MatProductCtx_MPIAIJKokkos *pdata; 1314076ba34aSJunchao Zhang MatMatStruct *mm = NULL; 13150e3ece09SJunchao Zhang PetscInt m, n, M, N; 13160e3ece09SJunchao Zhang Mat Cd, Co; 13170e3ece09SJunchao Zhang MPI_Comm comm; 1318f18d1d10SJunchao Zhang Mat_MPIAIJ *mpiaij; 1319076ba34aSJunchao Zhang 1320076ba34aSJunchao Zhang PetscFunctionBegin; 13210e3ece09SJunchao Zhang PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1322076ba34aSJunchao Zhang MatCheckProduct(C, 1); 13230e3ece09SJunchao Zhang product = C->product; 13240e3ece09SJunchao Zhang PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1325076ba34aSJunchao Zhang ptype = product->type; 1326076ba34aSJunchao Zhang A = product->A; 1327076ba34aSJunchao Zhang B = product->B; 1328076ba34aSJunchao Zhang 1329076ba34aSJunchao Zhang switch (ptype) { 13309371c9d4SSatish Balay case MATPRODUCT_AB: 13319371c9d4SSatish Balay m = A->rmap->n; 13329371c9d4SSatish Balay n = B->cmap->n; 13339371c9d4SSatish Balay M = A->rmap->N; 13349371c9d4SSatish Balay N = B->cmap->N; 13359371c9d4SSatish Balay break; 13369371c9d4SSatish Balay case MATPRODUCT_AtB: 13379371c9d4SSatish Balay m = A->cmap->n; 13389371c9d4SSatish Balay n = B->cmap->n; 13399371c9d4SSatish Balay M = A->cmap->N; 13409371c9d4SSatish Balay N = B->cmap->N; 13419371c9d4SSatish Balay break; 13429371c9d4SSatish Balay case MATPRODUCT_PtAP: 13439371c9d4SSatish Balay m = B->cmap->n; 13449371c9d4SSatish Balay n = B->cmap->n; 13459371c9d4SSatish Balay M = B->cmap->N; 13469371c9d4SSatish Balay N = B->cmap->N; 13479371c9d4SSatish Balay break; /* BtAB */ 1348d71ae5a4SJacob Faibussowitsch default: 13490e3ece09SJunchao Zhang SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1350076ba34aSJunchao Zhang } 1351076ba34aSJunchao Zhang 13529566063dSJacob Faibussowitsch PetscCall(MatSetSizes(C, m, n, M, N)); 13539566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(C->rmap)); 
13549566063dSJacob Faibussowitsch PetscCall(PetscLayoutSetUp(C->cmap)); 13559566063dSJacob Faibussowitsch PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1356076ba34aSJunchao Zhang 1357cc1eb50dSBarry Smith pdata = new MatProductCtx_MPIAIJKokkos(); 13580e3ece09SJunchao Zhang pdata->reusesym = product->api_user; 1359076ba34aSJunchao Zhang 1360076ba34aSJunchao Zhang if (ptype == MATPRODUCT_AB) { 13610e3ece09SJunchao Zhang auto mmAB = new MatMatStruct_AB(); 13620e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 13630e3ece09SJunchao Zhang mm = pdata->mmAB = mmAB; 1364076ba34aSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 13650e3ece09SJunchao Zhang auto mmAtB = new MatMatStruct_AtB(); 13660e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 13670e3ece09SJunchao Zhang mm = pdata->mmAtB = mmAtB; 13680e3ece09SJunchao Zhang } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 13690e3ece09SJunchao Zhang Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 13700e3ece09SJunchao Zhang 13710e3ece09SJunchao Zhang auto mmAB = new MatMatStruct_AB(); 13720e3ece09SJunchao Zhang PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 13730e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 13740e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 13750e3ece09SJunchao Zhang pdata->mmAB = mmAB; 13760e3ece09SJunchao Zhang 13770e3ece09SJunchao Zhang m = A->rmap->n; // Z's layout 13780e3ece09SJunchao Zhang n = B->cmap->n; 13790e3ece09SJunchao Zhang M = A->rmap->N; 13800e3ece09SJunchao Zhang N = B->cmap->N; 1381c0c276a7Ssdargavi PetscCall(MatCreateMPIAIJWithSeqAIJ(comm, M, N, Zd, Zo, mmAB->garray, &Z)); 13820e3ece09SJunchao Zhang 13830e3ece09SJunchao Zhang auto mmAtB = new MatMatStruct_AtB(); 13840e3ece09SJunchao Zhang 
PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 13850e3ece09SJunchao Zhang 13860e3ece09SJunchao Zhang pdata->Z = Z; // give ownership to pdata 13870e3ece09SJunchao Zhang mm = pdata->mmAtB = mmAtB; 1388076ba34aSJunchao Zhang } 13890e3ece09SJunchao Zhang 13900e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 13910e3ece09SJunchao Zhang PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1392f18d1d10SJunchao Zhang 1393f18d1d10SJunchao Zhang mpiaij = (Mat_MPIAIJ *)C->data; 1394f18d1d10SJunchao Zhang mpiaij->A = Cd; 1395f18d1d10SJunchao Zhang mpiaij->B = Co; 1396f18d1d10SJunchao Zhang mpiaij->garray = mm->garray; 1397f18d1d10SJunchao Zhang 1398f18d1d10SJunchao Zhang C->preallocated = PETSC_TRUE; 1399f18d1d10SJunchao Zhang C->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 1400f18d1d10SJunchao Zhang 1401f18d1d10SJunchao Zhang PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 1402f18d1d10SJunchao Zhang PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 1403f18d1d10SJunchao Zhang PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 1404f18d1d10SJunchao Zhang PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 1405f18d1d10SJunchao Zhang PetscCall(MatSetOption(C, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 1406f18d1d10SJunchao Zhang 140739cfb508SMark Adams /* set block sizes */ 140839cfb508SMark Adams switch (ptype) { 140939cfb508SMark Adams case MATPRODUCT_PtAP: 141039cfb508SMark Adams if (B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->cmap->bs, B->cmap->bs)); 141139cfb508SMark Adams break; 141239cfb508SMark Adams case MATPRODUCT_RARt: 141339cfb508SMark Adams if (B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, B->rmap->bs, B->rmap->bs)); 141439cfb508SMark Adams break; 141539cfb508SMark Adams case MATPRODUCT_ABC: 141639cfb508SMark Adams 
PetscCall(MatSetBlockSizesFromMats(C, A, product->C)); 141739cfb508SMark Adams break; 141839cfb508SMark Adams case MATPRODUCT_AB: 141939cfb508SMark Adams PetscCall(MatSetBlockSizesFromMats(C, A, B)); 142039cfb508SMark Adams break; 142139cfb508SMark Adams case MATPRODUCT_AtB: 142239cfb508SMark Adams if (A->cmap->bs > 1 || B->cmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->cmap->bs, B->cmap->bs)); 142339cfb508SMark Adams break; 142439cfb508SMark Adams case MATPRODUCT_ABt: 142539cfb508SMark Adams if (A->rmap->bs > 1 || B->rmap->bs > 1) PetscCall(MatSetBlockSizes(C, A->rmap->bs, B->rmap->bs)); 142639cfb508SMark Adams break; 142739cfb508SMark Adams default: 142839cfb508SMark Adams SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for ProductType %s", MatProductTypes[ptype]); 142939cfb508SMark Adams } 14300e3ece09SJunchao Zhang C->product->data = pdata; 1431cc1eb50dSBarry Smith C->product->destroy = MatProductCtxDestroy_MPIAIJKokkos; 1432076ba34aSJunchao Zhang C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 14333ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1434076ba34aSJunchao Zhang } 1435076ba34aSJunchao Zhang 1436d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1437d71ae5a4SJacob Faibussowitsch { 1438076ba34aSJunchao Zhang Mat_Product *product = mat->product; 1439076ba34aSJunchao Zhang PetscBool match = PETSC_FALSE; 1440076ba34aSJunchao Zhang PetscBool usecpu = PETSC_FALSE; 1441076ba34aSJunchao Zhang 1442076ba34aSJunchao Zhang PetscFunctionBegin; 1443076ba34aSJunchao Zhang MatCheckProduct(mat, 1); 144448a46eb9SPierre Jolivet if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1445076ba34aSJunchao Zhang if (match) { /* we can always fallback to the CPU if requested */ 1446076ba34aSJunchao Zhang switch (product->type) { 1447076ba34aSJunchao Zhang case MATPRODUCT_AB: 
1448076ba34aSJunchao Zhang if (product->api_user) { 1449d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 14509566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1451d0609cedSBarry Smith PetscOptionsEnd(); 1452076ba34aSJunchao Zhang } else { 1453d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 14549566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1455d0609cedSBarry Smith PetscOptionsEnd(); 1456076ba34aSJunchao Zhang } 1457076ba34aSJunchao Zhang break; 1458076ba34aSJunchao Zhang case MATPRODUCT_AtB: 1459076ba34aSJunchao Zhang if (product->api_user) { 1460d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 14619566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1462d0609cedSBarry Smith PetscOptionsEnd(); 1463076ba34aSJunchao Zhang } else { 1464d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 14659566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1466d0609cedSBarry Smith PetscOptionsEnd(); 1467076ba34aSJunchao Zhang } 1468076ba34aSJunchao Zhang break; 1469076ba34aSJunchao Zhang case MATPRODUCT_PtAP: 1470076ba34aSJunchao Zhang if (product->api_user) { 1471d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 14729566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU 
code", "MatPtAP", usecpu, &usecpu, NULL)); 1473d0609cedSBarry Smith PetscOptionsEnd(); 1474076ba34aSJunchao Zhang } else { 1475d0609cedSBarry Smith PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 14769566063dSJacob Faibussowitsch PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1477d0609cedSBarry Smith PetscOptionsEnd(); 1478076ba34aSJunchao Zhang } 1479076ba34aSJunchao Zhang break; 1480d71ae5a4SJacob Faibussowitsch default: 1481d71ae5a4SJacob Faibussowitsch break; 1482076ba34aSJunchao Zhang } 1483076ba34aSJunchao Zhang match = (PetscBool)!usecpu; 1484076ba34aSJunchao Zhang } 1485076ba34aSJunchao Zhang if (match) { 1486076ba34aSJunchao Zhang switch (product->type) { 1487076ba34aSJunchao Zhang case MATPRODUCT_AB: 1488076ba34aSJunchao Zhang case MATPRODUCT_AtB: 1489d71ae5a4SJacob Faibussowitsch case MATPRODUCT_PtAP: 1490d71ae5a4SJacob Faibussowitsch mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1491d71ae5a4SJacob Faibussowitsch break; 1492d71ae5a4SJacob Faibussowitsch default: 1493d71ae5a4SJacob Faibussowitsch break; 1494076ba34aSJunchao Zhang } 1495076ba34aSJunchao Zhang } 1496076ba34aSJunchao Zhang /* fallback to MPIAIJ ops */ 149748a46eb9SPierre Jolivet if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 14983ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1499076ba34aSJunchao Zhang } 1500076ba34aSJunchao Zhang 15012c4ab24aSJunchao Zhang // Mirror of MatCOOStruct_MPIAIJ on device 15022c4ab24aSJunchao Zhang struct MatCOOStruct_MPIAIJKokkos { 15032c4ab24aSJunchao Zhang PetscCount n; 15042c4ab24aSJunchao Zhang PetscSF sf; 15052c4ab24aSJunchao Zhang PetscCount Annz, Bnnz; 15062c4ab24aSJunchao Zhang PetscCount Annz2, Bnnz2; 15072c4ab24aSJunchao Zhang PetscCountKokkosView Ajmap1, Aperm1; 15082c4ab24aSJunchao Zhang PetscCountKokkosView Bjmap1, Bperm1; 15092c4ab24aSJunchao Zhang 
PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 15102c4ab24aSJunchao Zhang PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 15112c4ab24aSJunchao Zhang PetscCountKokkosView Cperm1; 15122c4ab24aSJunchao Zhang MatScalarKokkosView sendbuf, recvbuf; 15132c4ab24aSJunchao Zhang 151492896123SJunchao Zhang MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) 15152c4ab24aSJunchao Zhang { 15164df4a32cSJunchao Zhang auto exec = PetscGetKokkosExecutionSpace(); 151792896123SJunchao Zhang 151892896123SJunchao Zhang n = coo_h->n; 151992896123SJunchao Zhang sf = coo_h->sf; 152092896123SJunchao Zhang Annz = coo_h->Annz; 152192896123SJunchao Zhang Bnnz = coo_h->Bnnz; 152292896123SJunchao Zhang Annz2 = coo_h->Annz2; 152392896123SJunchao Zhang Bnnz2 = coo_h->Bnnz2; 152492896123SJunchao Zhang Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1)); 152592896123SJunchao Zhang Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1)); 152692896123SJunchao Zhang Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1)); 152792896123SJunchao Zhang Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1)); 152892896123SJunchao Zhang Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2)); 152992896123SJunchao Zhang Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1)); 153092896123SJunchao Zhang Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2)); 153192896123SJunchao Zhang Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2)); 153292896123SJunchao Zhang Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1)); 153392896123SJunchao 
Zhang Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2)); 153492896123SJunchao Zhang Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen)); 153592896123SJunchao Zhang sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen)); 153692896123SJunchao Zhang recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)); 15372c4ab24aSJunchao Zhang PetscCallVoid(PetscObjectReference((PetscObject)sf)); 15382c4ab24aSJunchao Zhang } 15392c4ab24aSJunchao Zhang 15402c4ab24aSJunchao Zhang ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 15412c4ab24aSJunchao Zhang }; 15422c4ab24aSJunchao Zhang 1543*2a8381b2SBarry Smith static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(PetscCtxRt data) 15442c4ab24aSJunchao Zhang { 15452c4ab24aSJunchao Zhang PetscFunctionBegin; 1546*2a8381b2SBarry Smith PetscCallCXX(delete *static_cast<MatCOOStruct_MPIAIJKokkos **>(data)); 15472c4ab24aSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 15482c4ab24aSJunchao Zhang } 15492c4ab24aSJunchao Zhang 1550d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1551d71ae5a4SJacob Faibussowitsch { 15522c4ab24aSJunchao Zhang PetscContainer container_h, container_d; 15532c4ab24aSJunchao Zhang MatCOOStruct_MPIAIJ *coo_h; 15542c4ab24aSJunchao Zhang MatCOOStruct_MPIAIJKokkos *coo_d; 155542550becSJunchao Zhang 155642550becSJunchao Zhang PetscFunctionBegin; 155730203840SJunchao Zhang PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1558cbc6b225SStefano Zampini mat->preallocated = PETSC_TRUE; 15599566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 
15609566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 15619566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(mat)); 15622c4ab24aSJunchao Zhang 15632c4ab24aSJunchao Zhang // Copy the COO struct to device 15642c4ab24aSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1565*2a8381b2SBarry Smith PetscCall(PetscContainerGetPointer(container_h, &coo_h)); 15662c4ab24aSJunchao Zhang PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 15672c4ab24aSJunchao Zhang 15682c4ab24aSJunchao Zhang // Put the COO struct in a container and then attach that to the matrix 15692c4ab24aSJunchao Zhang PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 15702c4ab24aSJunchao Zhang PetscCall(PetscContainerSetPointer(container_d, coo_d)); 157149abdd8aSBarry Smith PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 15722c4ab24aSJunchao Zhang PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 15732c4ab24aSJunchao Zhang PetscCall(PetscContainerDestroy(&container_d)); 15743ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 157542550becSJunchao Zhang } 157642550becSJunchao Zhang 1577d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1578d71ae5a4SJacob Faibussowitsch { 1579394ed5ebSJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 158042550becSJunchao Zhang Mat A = mpiaij->A, B = mpiaij->B; 158142550becSJunchao Zhang MatScalarKokkosView Aa, Ba; 1582394ed5ebSJunchao Zhang MatScalarKokkosView v1; 158342550becSJunchao Zhang PetscMemType memtype; 15842c4ab24aSJunchao Zhang PetscContainer container; 15852c4ab24aSJunchao Zhang MatCOOStruct_MPIAIJKokkos *coo; 15864df4a32cSJunchao Zhang Kokkos::DefaultExecutionSpace exec = PetscGetKokkosExecutionSpace(); 158742550becSJunchao Zhang 
158842550becSJunchao Zhang PetscFunctionBegin; 15892c4ab24aSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1590*2a8381b2SBarry Smith PetscCall(PetscContainerGetPointer(container, &coo)); 15912c4ab24aSJunchao Zhang 15922c4ab24aSJunchao Zhang const auto &n = coo->n; 15932c4ab24aSJunchao Zhang const auto &Annz = coo->Annz; 15942c4ab24aSJunchao Zhang const auto &Annz2 = coo->Annz2; 15952c4ab24aSJunchao Zhang const auto &Bnnz = coo->Bnnz; 15962c4ab24aSJunchao Zhang const auto &Bnnz2 = coo->Bnnz2; 15972c4ab24aSJunchao Zhang const auto &vsend = coo->sendbuf; 15982c4ab24aSJunchao Zhang const auto &v2 = coo->recvbuf; 15992c4ab24aSJunchao Zhang const auto &Ajmap1 = coo->Ajmap1; 16002c4ab24aSJunchao Zhang const auto &Ajmap2 = coo->Ajmap2; 16012c4ab24aSJunchao Zhang const auto &Aimap2 = coo->Aimap2; 16022c4ab24aSJunchao Zhang const auto &Bjmap1 = coo->Bjmap1; 16032c4ab24aSJunchao Zhang const auto &Bjmap2 = coo->Bjmap2; 16042c4ab24aSJunchao Zhang const auto &Bimap2 = coo->Bimap2; 16052c4ab24aSJunchao Zhang const auto &Aperm1 = coo->Aperm1; 16062c4ab24aSJunchao Zhang const auto &Aperm2 = coo->Aperm2; 16072c4ab24aSJunchao Zhang const auto &Bperm1 = coo->Bperm1; 16082c4ab24aSJunchao Zhang const auto &Bperm2 = coo->Bperm2; 16092c4ab24aSJunchao Zhang const auto &Cperm1 = coo->Cperm1; 16102c4ab24aSJunchao Zhang 16119566063dSJacob Faibussowitsch PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 161242550becSJunchao Zhang if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 161392896123SJunchao Zhang v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n)); 161442550becSJunchao Zhang } else { 16152c4ab24aSJunchao Zhang v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 161642550becSJunchao Zhang } 161742550becSJunchao Zhang 161842550becSJunchao Zhang if (imode == 
INSERT_VALUES) { 16199566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 16209566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1621394ed5ebSJunchao Zhang } else { 16229566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 16239566063dSJacob Faibussowitsch PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 162442550becSJunchao Zhang } 162542550becSJunchao Zhang 162608bb9926SJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 162742550becSJunchao Zhang /* Pack entries to be sent to remote */ 162892896123SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 162942550becSJunchao Zhang 163042550becSJunchao Zhang /* Send remote entries to their owner and overlap the communication with local computation */ 16312c4ab24aSJunchao Zhang PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1632158ec288SJunchao Zhang /* Add local entries to A and B in one kernel */ 16339371c9d4SSatish Balay Kokkos::parallel_for( 163492896123SJunchao Zhang Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1635158ec288SJunchao Zhang PetscScalar sum = 0.0; 1636158ec288SJunchao Zhang if (i < Annz) { 1637158ec288SJunchao Zhang for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1638ac38520cSJunchao Zhang Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1639158ec288SJunchao Zhang } else { 1640158ec288SJunchao Zhang i -= Annz; 1641158ec288SJunchao Zhang for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1642ac38520cSJunchao Zhang Ba(i) = (imode == INSERT_VALUES ? 
0.0 : Ba(i)) + sum; 1643158ec288SJunchao Zhang } 1644158ec288SJunchao Zhang }); 16452c4ab24aSJunchao Zhang PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 164642550becSJunchao Zhang 1647158ec288SJunchao Zhang /* Add received remote entries to A and B in one kernel */ 16489371c9d4SSatish Balay Kokkos::parallel_for( 164992896123SJunchao Zhang Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1650158ec288SJunchao Zhang if (i < Annz2) { 1651158ec288SJunchao Zhang for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1652158ec288SJunchao Zhang } else { 1653158ec288SJunchao Zhang i -= Annz2; 1654158ec288SJunchao Zhang for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1655158ec288SJunchao Zhang } 1656158ec288SJunchao Zhang }); 165708bb9926SJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 165842550becSJunchao Zhang 1659394ed5ebSJunchao Zhang if (imode == INSERT_VALUES) { 16609566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 16619566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1662394ed5ebSJunchao Zhang } else { 16639566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 16649566063dSJacob Faibussowitsch PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1665394ed5ebSJunchao Zhang } 16663ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 166742550becSJunchao Zhang } 166842550becSJunchao Zhang 16692c4ab24aSJunchao Zhang static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1670d71ae5a4SJacob Faibussowitsch { 1671076ba34aSJunchao Zhang PetscFunctionBegin; 16729566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 16739566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 16749566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 16759566063dSJacob Faibussowitsch PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 167657761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE) 167757761e9aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL)); 167857761e9aSJunchao Zhang #endif 16799566063dSJacob Faibussowitsch PetscCall(MatDestroy_MPIAIJ(A)); 16803ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1681076ba34aSJunchao Zhang } 1682076ba34aSJunchao Zhang 1683f4747e26SJunchao Zhang static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1684f4747e26SJunchao Zhang { 1685f4747e26SJunchao Zhang Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1686f4747e26SJunchao Zhang PetscBool congruent; 1687f4747e26SJunchao Zhang 1688f4747e26SJunchao Zhang PetscFunctionBegin; 1689f4747e26SJunchao Zhang PetscCall(MatHasCongruentLayouts(A, &congruent)); 1690f4747e26SJunchao Zhang if (congruent) { // square matrix and the diagonals are solely 
in the diag block 1691f4747e26SJunchao Zhang PetscCall(MatShift(mpiaij->A, a)); 1692f4747e26SJunchao Zhang } else { // too hard, use the general version 1693f4747e26SJunchao Zhang PetscCall(MatShift_Basic(A, a)); 1694f4747e26SJunchao Zhang } 1695f4747e26SJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1696f4747e26SJunchao Zhang } 1697f4747e26SJunchao Zhang 16982c4ab24aSJunchao Zhang static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 16992c4ab24aSJunchao Zhang { 17002c4ab24aSJunchao Zhang PetscFunctionBegin; 17012c4ab24aSJunchao Zhang B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 17022c4ab24aSJunchao Zhang B->ops->mult = MatMult_MPIAIJKokkos; 17032c4ab24aSJunchao Zhang B->ops->multadd = MatMultAdd_MPIAIJKokkos; 17042c4ab24aSJunchao Zhang B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 17052c4ab24aSJunchao Zhang B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 17062c4ab24aSJunchao Zhang B->ops->destroy = MatDestroy_MPIAIJKokkos; 1707f4747e26SJunchao Zhang B->ops->shift = MatShift_MPIAIJKokkos; 170803db1824SAlex Lindsay B->ops->getcurrentmemtype = MatGetCurrentMemType_MPIAIJ; 17092c4ab24aSJunchao Zhang 17102c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 17112c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 17122c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 17132c4ab24aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 171457761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE) 171557761e9aSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE)); 171657761e9aSJunchao Zhang #endif 17172c4ab24aSJunchao Zhang 
PetscFunctionReturn(PETSC_SUCCESS); 17182c4ab24aSJunchao Zhang } 17192c4ab24aSJunchao Zhang 1720d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1721d71ae5a4SJacob Faibussowitsch { 17228c3ff71bSJunchao Zhang Mat B; 1723076ba34aSJunchao Zhang Mat_MPIAIJ *a; 17248c3ff71bSJunchao Zhang 17258c3ff71bSJunchao Zhang PetscFunctionBegin; 17268c3ff71bSJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 17279566063dSJacob Faibussowitsch PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 17288c3ff71bSJunchao Zhang } else if (reuse == MAT_REUSE_MATRIX) { 17299566063dSJacob Faibussowitsch PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 17308c3ff71bSJunchao Zhang } 17318c3ff71bSJunchao Zhang B = *newmat; 17328c3ff71bSJunchao Zhang 17336f3d89d0SStefano Zampini B->boundtocpu = PETSC_FALSE; 17349566063dSJacob Faibussowitsch PetscCall(PetscFree(B->defaultvectype)); 17359566063dSJacob Faibussowitsch PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 17369566063dSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 17378c3ff71bSJunchao Zhang 1738076ba34aSJunchao Zhang a = static_cast<Mat_MPIAIJ *>(A->data); 17399566063dSJacob Faibussowitsch if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 17409566063dSJacob Faibussowitsch if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 17419566063dSJacob Faibussowitsch if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 17422c4ab24aSJunchao Zhang PetscCall(MatSetOps_MPIAIJKokkos(B)); 17433ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17448c3ff71bSJunchao Zhang } 17452c4ab24aSJunchao Zhang 17463f3ba80aSJunchao Zhang /*MC 17477f98ec86SVictor Eijkhout MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos. 
17488c3ff71bSJunchao Zhang 174915229ffcSPierre Jolivet A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 17503f3ba80aSJunchao Zhang 17512ef1f0ffSBarry Smith Options Database Key: 17522ef1f0ffSBarry Smith . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 17533f3ba80aSJunchao Zhang 17543f3ba80aSJunchao Zhang Level: beginner 17553f3ba80aSJunchao Zhang 17561cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 17573f3ba80aSJunchao Zhang M*/ 1758d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1759d71ae5a4SJacob Faibussowitsch { 17608c3ff71bSJunchao Zhang PetscFunctionBegin; 17619566063dSJacob Faibussowitsch PetscCall(PetscKokkosInitializeCheck()); 17629566063dSJacob Faibussowitsch PetscCall(MatCreate_MPIAIJ(A)); 17639566063dSJacob Faibussowitsch PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 17643ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 17658c3ff71bSJunchao Zhang } 17668c3ff71bSJunchao Zhang 17678c3ff71bSJunchao Zhang /*@C 1768f8d70eaaSPierre Jolivet MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format 17698c3ff71bSJunchao Zhang (the default parallel PETSc format). This matrix will ultimately pushed down 177020f4b53cSBarry Smith to Kokkos for calculations. 17718c3ff71bSJunchao Zhang 17728c3ff71bSJunchao Zhang Collective 17738c3ff71bSJunchao Zhang 17748c3ff71bSJunchao Zhang Input Parameters: 177511a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF` 177620f4b53cSBarry Smith . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 177720f4b53cSBarry Smith This value should be the same as the local size used in creating the 177820f4b53cSBarry Smith y vector for the matrix-vector product y = Ax. 177920f4b53cSBarry Smith . 
n - This value should be the same as the local size used in creating the 178020f4b53cSBarry Smith x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 178120f4b53cSBarry Smith calculated if N is given) For square matrices n is almost always `m`. 178220f4b53cSBarry Smith . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 178320f4b53cSBarry Smith . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 178420f4b53cSBarry Smith . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 178520f4b53cSBarry Smith (same value is used for all local rows) 178620f4b53cSBarry Smith . d_nnz - array containing the number of nonzeros in the various rows of the 178720f4b53cSBarry Smith DIAGONAL portion of the local submatrix (possibly different for each row) 178820f4b53cSBarry Smith or `NULL`, if `d_nz` is used to specify the nonzero structure. 178920f4b53cSBarry Smith The size of this array is equal to the number of local rows, i.e `m`. 179020f4b53cSBarry Smith For matrices you plan to factor you must leave room for the diagonal entry and 179120f4b53cSBarry Smith put in the entry even if it is zero. 179220f4b53cSBarry Smith . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 179320f4b53cSBarry Smith submatrix (same value is used for all local rows). 179420f4b53cSBarry Smith - o_nnz - array containing the number of nonzeros in the various rows of the 179520f4b53cSBarry Smith OFF-DIAGONAL portion of the local submatrix (possibly different for 179620f4b53cSBarry Smith each row) or `NULL`, if `o_nz` is used to specify the nonzero 179720f4b53cSBarry Smith structure. The size of this array is equal to the number 179820f4b53cSBarry Smith of local rows, i.e `m`. 17998c3ff71bSJunchao Zhang 18008c3ff71bSJunchao Zhang Output Parameter: 18018c3ff71bSJunchao Zhang . 
A - the matrix 18028c3ff71bSJunchao Zhang 18032ef1f0ffSBarry Smith Level: intermediate 18042ef1f0ffSBarry Smith 18052ef1f0ffSBarry Smith Notes: 180611a5261eSBarry Smith It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 18078c3ff71bSJunchao Zhang MatXXXXSetPreallocation() paradigm instead of this routine directly. 180811a5261eSBarry Smith [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 18098c3ff71bSJunchao Zhang 1810667f096bSBarry Smith The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 18118c3ff71bSJunchao Zhang storage. That is, the stored row and column indices can begin at 18122ef1f0ffSBarry Smith either one (as in Fortran) or zero. 18138c3ff71bSJunchao Zhang 1814f8d70eaaSPierre Jolivet .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1815f8d70eaaSPierre Jolivet `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()` 18168c3ff71bSJunchao Zhang @*/ 1817d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1818d71ae5a4SJacob Faibussowitsch { 18198c3ff71bSJunchao Zhang PetscMPIInt size; 18208c3ff71bSJunchao Zhang 18218c3ff71bSJunchao Zhang PetscFunctionBegin; 18229566063dSJacob Faibussowitsch PetscCall(MatCreate(comm, A)); 18239566063dSJacob Faibussowitsch PetscCall(MatSetSizes(*A, m, n, M, N)); 18249566063dSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size)); 18258c3ff71bSJunchao Zhang if (size > 1) { 18269566063dSJacob Faibussowitsch PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 18279566063dSJacob Faibussowitsch PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 18288c3ff71bSJunchao Zhang } else { 18299566063dSJacob Faibussowitsch PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 
18309566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 18318c3ff71bSJunchao Zhang } 18323ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 18338c3ff71bSJunchao Zhang } 1834