1 #include <petsc_kokkos.hpp> 2 #include <petscvec_kokkos.hpp> 3 #include <petscpkg_version.h> 4 #include <petsc/private/sfimpl.h> 5 #include <petsc/private/kokkosimpl.hpp> 6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 7 #include <../src/mat/impls/aij/mpi/mpiaij.h> 8 #include <KokkosSparse_spadd.hpp> 9 #include <KokkosSparse_spgemm.hpp> 10 11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 12 { 13 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 14 15 PetscFunctionBegin; 16 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 17 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 18 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 19 */ 20 if (mode == MAT_FINAL_ASSEMBLY) { 21 PetscScalarKokkosView v; 22 23 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 24 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 25 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device 26 PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device 27 PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v)); 28 } 29 PetscFunctionReturn(PETSC_SUCCESS); 30 } 31 32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 33 { 34 Mat_MPIAIJ *mpiaij; 35 36 PetscFunctionBegin; 37 // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things 38 PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz)); 39 mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 40 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A)); 41 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B)); 42 PetscFunctionReturn(PETSC_SUCCESS); 43 } 44 45 static 
PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 46 { 47 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 48 PetscInt nt; 49 50 PetscFunctionBegin; 51 PetscCall(VecGetLocalSize(xx, &nt)); 52 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 53 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 54 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 55 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 56 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 57 PetscFunctionReturn(PETSC_SUCCESS); 58 } 59 60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 61 { 62 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 63 PetscInt nt; 64 65 PetscFunctionBegin; 66 PetscCall(VecGetLocalSize(xx, &nt)); 67 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 68 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 69 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 70 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 71 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 72 PetscFunctionReturn(PETSC_SUCCESS); 73 } 74 75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 76 { 77 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 78 PetscInt nt; 79 80 PetscFunctionBegin; 81 PetscCall(VecGetLocalSize(xx, &nt)); 82 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 83 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 84 
PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 85 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 86 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 87 PetscFunctionReturn(PETSC_SUCCESS); 88 } 89 90 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 91 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 92 C still uses local column ids. Their corresponding global column ids are returned in glob. 93 */ 94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 95 { 96 Mat Ad, Ao; 97 const PetscInt *cmap; 98 99 PetscFunctionBegin; 100 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 101 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 102 if (glob) { 103 PetscInt cst, i, dn, on, *gidx; 104 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 105 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 106 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 107 PetscCall(PetscMalloc1(dn + on, &gidx)); 108 for (i = 0; i < dn; i++) gidx[i] = cst + i; 109 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 110 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 111 } 112 PetscFunctionReturn(PETSC_SUCCESS); 113 } 114 115 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 116 struct MatMatStruct { 117 PetscInt n, *garray; // C's garray and its size. 
118 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 119 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 120 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 121 PetscIntKokkosView E_NzLeft; 122 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 123 MatScalarKokkosView rootBuf, leafBuf; 124 KokkosCsrMatrix Fd, Fo; // F in split form 125 126 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 127 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 128 KernelHandle kh3; // compute C3 129 KernelHandle kh4; // compute C4 130 131 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 132 PetscInt E_VectorLength; 133 PetscInt E_RowsPerTeam; 134 PetscInt F_TeamSize; 135 PetscInt F_VectorLength; 136 PetscInt F_RowsPerTeam; 137 138 ~MatMatStruct() 139 { 140 PetscFunctionBegin; 141 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 142 PetscFunctionReturnVoid(); 143 } 144 }; 145 146 struct MatMatStruct_AB : public MatMatStruct { 147 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 148 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 149 PetscIntKokkosView rowoffset; 150 }; 151 152 struct MatMatStruct_AtB : public MatMatStruct { 153 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 154 MatColIdxKokkosView Fdjperm; 155 MatColIdxKokkosView Fojmap; 156 MatColIdxKokkosView Fojperm; 157 }; 158 159 struct MatProductData_MPIAIJKokkos { 160 MatMatStruct_AB *mmAB = nullptr; 161 MatMatStruct_AtB *mmAtB = nullptr; 162 PetscBool reusesym = PETSC_FALSE; 163 Mat Z = nullptr; // store Z=AB in computing BtAB 164 165 ~MatProductData_MPIAIJKokkos() 166 { 167 delete mmAB; 168 delete mmAtB; 169 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 170 } 171 }; 172 173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 174 { 175 PetscFunctionBegin; 176 
PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 177 PetscFunctionReturn(PETSC_SUCCESS); 178 } 179 180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 181 It is similar to MatCreateMPIAIJWithSplitArrays. 182 183 Input Parameters: 184 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 185 . A - the diag matrix using local col ids 186 - B - the offdiag matrix using global col ids 187 188 Output Parameter: 189 . mat - the updated MATMPIAIJKOKKOS matrix 190 */ 191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 192 { 193 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 194 PetscInt m, n, M, N, Am, An, Bm, Bn; 195 196 PetscFunctionBegin; 197 PetscCall(MatGetSize(mat, &M, &N)); 198 PetscCall(MatGetLocalSize(mat, &m, &n)); 199 PetscCall(MatGetLocalSize(A, &Am, &An)); 200 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 201 202 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 203 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 204 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 205 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 206 mpiaij->A = A; 207 mpiaij->B = B; 208 mpiaij->garray = garray; 209 210 mat->preallocated = PETSC_TRUE; 211 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 212 213 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 214 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 215 /* MatAssemblyEnd is critical here. 
It sets mat->offloadmask according to A and B's, and 216 also gets mpiaij->B compacted, with its col ids and size reduced 217 */ 218 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 219 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 220 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 221 PetscFunctionReturn(PETSC_SUCCESS); 222 } 223 224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 226 template <class ExecutionSpace> 227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 228 { 229 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LE(4, 4, 1) 230 constexpr bool is_gpu_exec_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>(); 231 #else 232 constexpr bool is_gpu_exec_space = KokkosKernels::Impl::is_gpu_exec_space_v<ExecutionSpace>; 233 #endif 234 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 235 236 PetscFunctionBegin; 237 PetscInt nnz_per_row = numRows ? 
(nnz / numRows) : 0; // we might meet empty matrices 238 239 if (nnz_per_row < 1) nnz_per_row = 1; 240 241 int max_vector_length = teamPolicy.vector_length_max(); 242 243 if (vector_length < 1) { 244 vector_length = 1; 245 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 246 } 247 248 // Determine rows per thread 249 if (rows_per_thread < 1) { 250 if (is_gpu_exec_space) rows_per_thread = 1; 251 else { 252 if (nnz_per_row < 20 && nnz > 5000000) { 253 rows_per_thread = 256; 254 } else rows_per_thread = 64; 255 } 256 } 257 258 if (team_size < 1) { 259 if (is_gpu_exec_space) { 260 team_size = 256 / vector_length; 261 } else { 262 team_size = 1; 263 } 264 } 265 266 rows_per_team = rows_per_thread * team_size; 267 268 if (rows_per_team < 0) { 269 PetscInt nnz_per_team = 4096; 270 PetscInt conc = ExecutionSpace().concurrency(); 271 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 272 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 273 } 274 PetscFunctionReturn(PETSC_SUCCESS); 275 } 276 277 /* 278 Reduce two sets of global indices into local ones 279 280 Input Parameters: 281 + n1 - size of garray1[], the first set 282 . garray1[n1] - a sorted global index array (without duplicates) 283 . m - size of indices[], the second set 284 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 285 286 Output Parameters: 287 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 288 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 289 . map[n1] - allocated by caller. 
It gives garray1[i] = garray2[map[i]] 290 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 291 292 Example, say 293 n1 = 5 294 garray1[5] = {1, 4, 7, 8, 10} 295 m = 4 296 indices[4] = {2, 4, 8, 9} 297 298 Combining them together, we have 7 global indices in garray2[] 299 n2 = 7 300 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 301 302 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 303 map[5] = {0, 2, 3, 4, 6} 304 305 On output, indices[] is updated with local indices 306 indices[4] = {1, 2, 4, 5} 307 */ 308 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 309 { 310 PetscHMapI g2l = nullptr; 311 PetscHashIter iter; 312 PetscInt tot, key, val; // total unique global indices. key is global id; val is local id 313 PetscInt n2, *garray2; 314 315 PetscFunctionBegin; 316 tot = 0; 317 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 318 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 319 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 320 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 321 } 322 323 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 324 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 325 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 326 } 327 328 // Pull out (unique) globals in the hash table and put them in garray2[] 329 n2 = tot; 330 PetscCall(PetscMalloc1(n2, &garray2)); 331 tot = 0; 332 PetscHashIterBegin(g2l, iter); 333 while (!PetscHashIterAtEnd(g2l, iter)) { 334 PetscHashIterGetKey(g2l, iter, key); 335 PetscHashIterNext(g2l, iter); 336 garray2[tot++] = key; 337 } 338 339 // Sort garray2[] and then map them to local indices starting from 0 340 
PetscCall(PetscSortInt(n2, garray2)); 341 PetscCall(PetscHMapIClear(g2l)); 342 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 343 344 // Rewrite indices[] with local indices 345 for (PetscInt i = 0; i < m; i++) { 346 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 347 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 348 indices[i] = val; 349 } 350 // Record the map that maps garray1[i] to garray2[map[i]] 351 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 352 PetscCall(PetscHMapIDestroy(&g2l)); 353 *n2_ = n2; 354 *garray2_ = garray2; 355 PetscFunctionReturn(PETSC_SUCCESS); 356 } 357 358 /* 359 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 360 361 It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 362 363 Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 364 In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 365 366 Input Parameters: 367 + comm - MPI communicator of E 368 . A - diag block of E, using local column indices 369 . B - off-diag block of E, using local column indices 370 . cstart - (global) start column of Ed 371 . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend) 372 . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size. 373 . ownerSF - the SF specifies ownership (root) of rows in E 374 . 
reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 375 - mm - to stash intermediate data structures for reuse 376 377 Output Parameters: 378 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices(). 379 - mm - contains various info, such as garray2[], F (Fd, Fo) etc. 380 381 Notes: 382 When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant. 383 384 */ 385 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 386 { 387 PetscFunctionBegin; 388 if (reuse == MAT_INITIAL_MATRIX) { 389 PetscInt Em = A.numRows(), Fm; 390 PetscInt n1 = B.numCols(); 391 392 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 393 394 // Do the analysis on host 395 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 396 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 397 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 398 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 399 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 400 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 401 402 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 403 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 404 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 405 for (PetscInt i = 0; i < Em; i++) { 406 const PetscInt *first, *last, *it; 407 PetscInt count, step; 408 // std::lower_bound(first,last,cstart), but need to use global column indices 409 first = Bj + Bi[i]; 410 last = Bj + Bi[i + 1]; 411 count = last - first; 412 while (count > 
0) { 413 it = first; 414 step = count / 2; 415 it += step; 416 if (garray1[*it] < cstart) { // map local to global 417 first = ++it; 418 count -= step + 1; 419 } else count = step; 420 } 421 E_NzLeft[i] = first - (Bj + Bi[i]); 422 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 423 } 424 425 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 426 const PetscMPIInt *iranks, *ranks; 427 const PetscInt *ioffset, *irootloc, *roffset, *rmine; 428 PetscMPIInt niranks, nranks; 429 MPI_Request *reqs; 430 PetscMPIInt tag; 431 PetscSF reduceSF; 432 PetscInt *sdisp, *rdisp; 433 434 PetscCall(PetscCommGetNewTag(comm, &tag)); 435 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 436 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 437 438 // Find out length of each row I will receive. 
Even for the same row index, when they are from 439 // different senders, they might have different lengths (and sparsity patterns) 440 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 441 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 442 443 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 444 445 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 446 recvRowLen[0] = 0; // since we will make it in CSR format later 447 recvRowLen++; // advance the pointer now 448 for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 449 for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); 450 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 451 452 // Build the real PetscSF for reducing E rows (buffer to buffer) 453 rdisp[0] = 0; 454 for (PetscInt i = 0; i < niranks; i++) { 455 rdisp[i + 1] = rdisp[i]; 456 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 457 } 458 recvRowLen--; // put it back into csr format 459 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i]; 460 461 for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); 462 for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 463 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 464 465 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 466 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 467 PetscSFNode *iremote; 468 469 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 470 PetscAssert(A.nnz() + 
B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 471 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 472 473 for (PetscInt i = 0; i < nranks; i++) { 474 PetscInt count = 0; 475 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 476 for (PetscInt j = 0; j < count; j++) { 477 iremote[nleaves + j].rank = ranks[i]; 478 iremote[nleaves + j].index = sdisp[i] + j; 479 } 480 nleaves += count; 481 } 482 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 483 484 PetscCall(PetscSFCreate(comm, &reduceSF)); 485 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 486 487 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 488 PetscInt *sendCol, *recvCol; 489 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 490 for (PetscInt k = 0; k < roffset[nranks]; k++) { 491 PetscInt i = rmine[k]; // row to be copied 492 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 493 PetscInt nzLeft = E_NzLeft[i]; 494 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 495 for (PetscInt j = 0; j < alen + blen; j++) { 496 if (j < nzLeft) { 497 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 498 } else if (j < nzLeft + alen) { 499 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 500 } else { 501 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 502 } 503 } 504 } 505 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 506 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 507 508 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 509 PetscInt *recvRowPerm, *recvColSorted; 510 PetscInt *recvNzPerm, 
*recvNzPermSorted; 511 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 512 513 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 514 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 515 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 516 517 // i[] array, nz are always easiest to compute 518 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 519 MatRowMapType *Fdi, *Foi; 520 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 521 PetscInt iter; 522 523 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 524 Kokkos::deep_copy(Foi_h, 0); 525 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 526 Foi = Foi_h.data() + 1; 527 iter = 0; 528 while (iter < recvRowCnt) { // iter over received rows 529 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 530 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 531 532 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 533 534 // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted 535 PetscInt nz = 0; // nz (with dups) in the current row 536 PetscInt *jbuf = recvColSorted + FnzDups; 537 PetscInt *pbuf = recvNzPermSorted + FnzDups; 538 PetscInt *jbuf2 = jbuf; // temp pointers 539 PetscInt *pbuf2 = pbuf; 540 for (PetscInt d = 0; d < dupRows; d++) { 541 PetscInt i = recvRowPerm[iter + d]; 542 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 543 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 544 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 
545 jbuf2 += len; 546 pbuf2 += len; 547 nz += len; 548 } 549 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 550 551 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 552 PetscInt cur = 0; 553 while (cur < nz) { 554 PetscInt curColIdx = jbuf[cur]; 555 PetscInt dups = 1; 556 557 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 558 if (curColIdx >= cstart && curColIdx < cend) { 559 Fdi[curRowIdx]++; 560 FdnzDups += dups; 561 } else { 562 Foi[curRowIdx]++; 563 FonzDups += dups; 564 } 565 cur += dups; 566 } 567 568 FnzDups += nz; 569 iter += dupRows; // Move to next unique row 570 } 571 572 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 573 Foi = Foi_h.data(); 574 for (PetscInt i = 0; i < Fm; i++) { 575 Fdi[i + 1] += Fdi[i]; 576 Foi[i + 1] += Foi[i]; 577 } 578 Fdnz = Fdi[Fm]; 579 Fonz = Foi[Fm]; 580 PetscCall(PetscFree2(sendCol, recvCol)); 581 582 // Allocate j, jmap, jperm for Fd and Fo 583 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 584 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 585 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 586 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 587 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 588 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 589 590 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 591 Fdjmap[0] = 0; 592 Fojmap[0] = 0; 593 FnzDups = 0; 594 Fdnz = 0; 595 Fonz = 0; 596 iter = 0; // iter over received rows 597 while (iter < recvRowCnt) { 598 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 599 PetscInt dupRows = 1; // It has this many contributing rows (of 
various lengths) 600 PetscInt nz = 0; // nz (with dups) in the current row 601 602 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 603 for (PetscInt d = 0; d < dupRows; d++) { 604 PetscInt i = recvRowPerm[iter + d]; 605 nz += recvRowLen[i + 1] - recvRowLen[i]; 606 } 607 608 PetscInt *jbuf = recvColSorted + FnzDups; 609 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 610 PetscInt cur = 0; 611 while (cur < nz) { 612 PetscInt curColIdx = jbuf[cur]; 613 PetscInt dups = 1; 614 615 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 616 if (curColIdx >= cstart && curColIdx < cend) { 617 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 618 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 619 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 620 FdnzDups += dups; 621 Fdnz++; 622 } else { 623 Foj[Fonz] = curColIdx; // in global 624 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 625 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 626 FonzDups += dups; 627 Fonz++; 628 } 629 cur += dups; 630 FnzDups += dups; 631 } 632 iter += dupRows; // Move to next unique row 633 } 634 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 635 PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs)); 636 637 // Combine global column indices in garray1[] and Foj[] 638 PetscInt n2, *garray2; 639 640 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 641 mm->sf = reduceSF; 642 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 643 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 644 mm->garray = garray2; // give ownership, so no free 645 mm->n = n2; 646 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 647 mm->Fdjmap = 
Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 648 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 649 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 650 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 651 652 // Output Fd and Fo in KokkosCsrMatrix format 653 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 654 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 655 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 656 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 657 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 658 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 659 660 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 661 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 662 663 // Compute kernel launch parameters in merging E 664 PetscInt teamSize, vectorLength, rowsPerTeam; 665 666 teamSize = vectorLength = rowsPerTeam = -1; 667 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 668 mm->E_TeamSize = teamSize; 669 mm->E_VectorLength = vectorLength; 670 mm->E_RowsPerTeam = rowsPerTeam; 671 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 672 673 // Handy aliases 674 auto &Aa = A.values; 675 auto &Ba = B.values; 676 const auto &Ai = A.graph.row_map; 677 const auto &Bi = B.graph.row_map; 678 const auto &E_NzLeft = mm->E_NzLeft; 679 auto &leafBuf = mm->leafBuf; 680 auto &rootBuf = mm->rootBuf; 681 PetscSF reduceSF = mm->sf; 682 PetscInt Em = A.numRows(); 683 PetscInt teamSize = mm->E_TeamSize; 684 PetscInt vectorLength = 
mm->E_VectorLength; 685 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 686 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 687 688 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 689 PetscCallCXX(Kokkos::parallel_for( 690 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 691 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 692 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 693 if (i < Em) { 694 PetscInt disp = Ai(i) + Bi(i); 695 PetscInt alen = Ai(i + 1) - Ai(i); 696 PetscInt blen = Bi(i + 1) - Bi(i); 697 PetscInt nzleft = E_NzLeft(i); 698 699 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 700 MatScalar &val = leafBuf(disp + j); 701 if (j < nzleft) { // B left 702 val = Ba(Bi(i) + j); 703 } else if (j < nzleft + alen) { // diag A 704 val = Aa(Ai(i) + j - nzleft); 705 } else { // B right 706 val = Ba(Bi(i) + j - alen); 707 } 708 }); 709 } 710 }); 711 })); 712 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 713 PetscFunctionReturn(PETSC_SUCCESS); 714 } 715 716 // To finish MatMPIAIJKokkosReduce. 
717 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 718 { 719 auto &leafBuf = mm->leafBuf; 720 auto &rootBuf = mm->rootBuf; 721 auto &Fda = mm->Fd.values; 722 const auto &Fdjmap = mm->Fdjmap; 723 const auto &Fdjperm = mm->Fdjperm; 724 auto Fdnz = mm->Fd.nnz(); 725 auto &Foa = mm->Fo.values; 726 const auto &Fojmap = mm->Fojmap; 727 const auto &Fojperm = mm->Fojperm; 728 auto Fonz = mm->Fo.nnz(); 729 PetscSF reduceSF = mm->sf; 730 731 PetscFunctionBegin; 732 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 733 734 // Reduce data in rootBuf to Fd and Fo 735 PetscCallCXX(Kokkos::parallel_for( 736 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 737 PetscScalar sum = 0.0; 738 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 739 Fda(i) = sum; 740 })); 741 742 PetscCallCXX(Kokkos::parallel_for( 743 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 744 PetscScalar sum = 0.0; 745 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 746 Foa(i) = sum; 747 })); 748 PetscFunctionReturn(PETSC_SUCCESS); 749 } 750 751 /* 752 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 753 754 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 755 device and involves various index mapping. 756 757 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 
  Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i); this means we need to bcast the i-th row of E on rank k
  to the j-th row of F. ownerSF is not an arbitrary SF; instead, it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
  F has the same column layout as E.

  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
  Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
  Fo has global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
  column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
  column indices in Fo and update Fo with local indices.

  Input Parameters:
+ E       - the MPIAIJKOKKOS matrix
. ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
. reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm      - to stash matproduct intermediate data structures

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
- mm      - contains various info, such as garray2[], Fd, Fo, etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, ownerSF and map are not significant.
  The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide opportunities for overlapping computation and communication.
781 */ 782 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 783 { 784 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 785 Mat A = empi->A, B = empi->B; // diag and off-diag 786 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 787 PetscInt Em = E->rmap->n; // #local rows 788 MPI_Comm comm; 789 790 PetscFunctionBegin; 791 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 792 if (reuse == MAT_INITIAL_MATRIX) { 793 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 794 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 795 const PetscInt *garray1 = empi->garray; // its size is n1 796 PetscInt cstart, cend; 797 PetscSF bcastSF; 798 799 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 800 801 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 802 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 803 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 804 for (PetscInt i = 0; i < Em; i++) { 805 const PetscInt *first, *last, *it; 806 PetscInt count, step; 807 // std::lower_bound(first,last,cstart), but need to use global column indices 808 first = Bj + Bi[i]; 809 last = Bj + Bi[i + 1]; 810 count = last - first; 811 while (count > 0) { 812 it = first; 813 step = count / 2; 814 it += step; 815 if (empi->garray[*it] < cstart) { // map local to global 816 first = ++it; 817 count -= step + 1; 818 } else count = step; 819 } 820 E_NzLeft[i] = first - (Bj + Bi[i]); 821 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 822 } 823 824 // Compute row pointer Fi of F 825 PetscInt *Fi, Fm, Fnz; 826 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 827 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 828 Fi[0] = 0; 829 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 830 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 831 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 832 Fnz = Fi[Fm]; 833 834 // Build the real PetscSF for bcasting E rows (buffer to buffer) 835 const PetscMPIInt *iranks, *ranks; 836 const PetscInt *ioffset, *irootloc, *roffset; 837 PetscMPIInt niranks, nranks; 838 PetscInt *sdisp, *rdisp; 839 MPI_Request *reqs; 840 PetscMPIInt tag; 841 842 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 843 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 844 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 845 846 sdisp[0] = 0; // send displacement 847 for (PetscInt i = 0; i < niranks; i++) { 848 sdisp[i + 1] = sdisp[i]; 849 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 850 PetscInt r = irootloc[j]; // row to be sent 851 sdisp[i + 1] += E_RowLen[r]; 852 } 853 } 854 855 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 856 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 857 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 858 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 859 860 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 861 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 862 PetscSFNode *iremote; // give ownership to bcastSF 863 PetscCall(PetscMalloc1(nleaves, &iremote)); 864 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 865 PetscInt k = 0; 866 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows 
[roffset[i], roffset[i+1]) of F from ranks[i] 867 iremote[j].rank = ranks[i]; 868 iremote[j].index = rdisp[i] + k; // their root location 869 k++; 870 } 871 } 872 PetscCall(PetscSFCreate(comm, &bcastSF)); 873 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 874 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 875 876 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 877 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 878 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 879 rowoffset[0] = 0; 880 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 881 882 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 883 PetscInt *jbuf, *Fj; 884 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 885 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 886 PetscInt i = irootloc[k]; // row to be copied 887 PetscInt *buf = &jbuf[rowoffset[k]]; 888 PetscInt nzLeft = E_NzLeft[i]; 889 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 890 for (PetscInt j = 0; j < alen + blen; j++) { 891 if (j < nzLeft) { 892 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 893 } else if (j < nzLeft + alen) { 894 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 895 } else { 896 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 897 } 898 } 899 } 900 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 901 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 902 903 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 904 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 905 
MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 906 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 907 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 908 909 Fdi[0] = Foi[0] = 0; 910 for (PetscInt i = 0; i < Fm; i++) { 911 PetscInt *first, *last, *lb1, *lb2; 912 // cut the row into: Left, [cstart, cend), Right 913 first = Fj + Fi[i]; 914 last = Fj + Fi[i + 1]; 915 lb1 = std::lower_bound(first, last, cstart); 916 F_NzLeft[i] = lb1 - first; 917 lb2 = std::lower_bound(first, last, cend); 918 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 919 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 920 } 921 for (PetscInt i = 0; i < Fm; i++) { 922 Fdi[i + 1] += Fdi[i]; 923 Foi[i + 1] += Foi[i]; 924 } 925 926 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 927 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 928 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 929 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 930 931 for (PetscInt i = 0; i < Fm; i++) { 932 PetscInt nzLeft = F_NzLeft[i]; 933 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 934 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 935 gid = Fj[Fi[i] + j]; 936 if (j < nzLeft) { // left, in global 937 Foj[Foi[i] + j] = gid; 938 } else if (j < nzLeft + len) { // diag, in local 939 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 940 } else { // right, in global 941 Foj[Foi[i] + j - len] = gid; 942 } 943 } 944 } 945 PetscCall(PetscFree2(jbuf, Fj)); 946 PetscCall(PetscFree(Fi)); 947 948 // Reduce global indices in Foj[] and garray1[] into local ones 949 PetscInt n2, *garray2; 950 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 951 952 // Record the plans built above, for reuse 953 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] 
is owned by ownerSF. We create a copy for safety 954 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 955 Kokkos::deep_copy(irootloc_h, tmp); 956 mm->sf = bcastSF; 957 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 958 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 959 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 960 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 961 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 962 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 963 mm->garray = garray2; 964 mm->n = n2; 965 966 // Output Fd and Fo in KokkosCsrMatrix format 967 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 968 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 969 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 970 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 971 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 972 973 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 974 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 975 976 // Compute kernel launch parameters in merging E or splitting F 977 PetscInt teamSize, vectorLength, rowsPerTeam; 978 979 teamSize = vectorLength = rowsPerTeam = -1; 980 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 981 mm->E_TeamSize = teamSize; 982 mm->E_VectorLength = vectorLength; 983 mm->E_RowsPerTeam = rowsPerTeam; 984 985 teamSize = vectorLength = rowsPerTeam = -1; 986 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, 
Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 987 mm->F_TeamSize = teamSize; 988 mm->F_VectorLength = vectorLength; 989 mm->F_RowsPerTeam = rowsPerTeam; 990 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 991 992 // Sync E's value to device 993 akok->a_dual.sync_device(); 994 bkok->a_dual.sync_device(); 995 996 // Handy aliases 997 const auto &Aa = akok->a_dual.view_device(); 998 const auto &Ba = bkok->a_dual.view_device(); 999 const auto &Ai = akok->i_dual.view_device(); 1000 const auto &Bi = bkok->i_dual.view_device(); 1001 1002 // Fetch the plans 1003 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1004 PetscSF &bcastSF = mm->sf; 1005 MatScalarKokkosView &rootBuf = mm->rootBuf; 1006 MatScalarKokkosView &leafBuf = mm->leafBuf; 1007 PetscIntKokkosView &irootloc = mm->irootloc; 1008 PetscIntKokkosView &rowoffset = mm->rowoffset; 1009 1010 PetscInt teamSize = mm->E_TeamSize; 1011 PetscInt vectorLength = mm->E_VectorLength; 1012 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1013 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1014 1015 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1016 PetscCallCXX(Kokkos::parallel_for( 1017 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1018 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1019 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1020 if (r < irootloc.extent(0)) { 1021 PetscInt i = irootloc(r); // row i of E 1022 PetscInt disp = rowoffset(r); 1023 PetscInt alen = Ai(i + 1) - Ai(i); 1024 PetscInt blen = Bi(i + 1) - Bi(i); 1025 PetscInt nzleft = E_NzLeft(i); 1026 1027 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1028 if (j < nzleft) { // B left 1029 rootBuf(disp + j) = Ba(Bi(i) + j); 1030 } else if (j < nzleft + alen) { // diag A 1031 rootBuf(disp 
+ j) = Aa(Ai(i) + j - nzleft); 1032 } else { // B right 1033 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1034 } 1035 }); 1036 } 1037 }); 1038 })); 1039 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1040 PetscFunctionReturn(PETSC_SUCCESS); 1041 } 1042 1043 // To finish MatMPIAIJKokkosBcast. 1044 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1045 { 1046 PetscFunctionBegin; 1047 const auto &Fd = mm->Fd; 1048 const auto &Fo = mm->Fo; 1049 const auto &Fdi = Fd.graph.row_map; 1050 const auto &Foi = Fo.graph.row_map; 1051 auto &Fda = Fd.values; 1052 auto &Foa = Fo.values; 1053 auto Fm = Fd.numRows(); 1054 1055 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1056 PetscSF &bcastSF = mm->sf; 1057 MatScalarKokkosView &rootBuf = mm->rootBuf; 1058 MatScalarKokkosView &leafBuf = mm->leafBuf; 1059 PetscInt teamSize = mm->F_TeamSize; 1060 PetscInt vectorLength = mm->F_VectorLength; 1061 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1062 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1063 1064 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1065 1066 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1067 PetscCallCXX(Kokkos::parallel_for( 1068 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1069 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1070 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1071 if (i < Fm) { 1072 PetscInt nzLeft = F_NzLeft(i); 1073 PetscInt alen = Fdi(i + 1) - Fdi(i); 1074 PetscInt blen = Foi(i + 1) - Foi(i); 1075 PetscInt Fii = Fdi(i) + Foi(i); 1076 1077 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1078 PetscScalar val = leafBuf(Fii + j); 1079 if 
(j < nzLeft) { // left 1080 Foa(Foi(i) + j) = val; 1081 } else if (j < nzLeft + alen) { // diag 1082 Fda(Fdi(i) + j - nzLeft) = val; 1083 } else { // right 1084 Foa(Foi(i) + j - alen) = val; 1085 } 1086 }); 1087 } 1088 }); 1089 })); 1090 PetscFunctionReturn(PETSC_SUCCESS); 1091 } 1092 1093 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1094 { 1095 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1096 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1097 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1098 PetscInt cstart, cend; 1099 MPI_Comm comm; 1100 1101 PetscFunctionBegin; 1102 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1103 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1104 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1105 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1106 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1107 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1108 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1109 1110 // TODO: add command line options to select spgemm algorithms 1111 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1112 1113 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1114 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1115 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1116 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1117 #endif 1118 #endif 1119 1120 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1121 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1122 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1123 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1124 1125 // Aot * (B's diag + B's off-diag) 1126 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1127 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1128 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1129 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1130 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1131 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1132 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1133 1134 PetscCallCXX(sort_crs_matrix(mm->C3)); 1135 PetscCallCXX(sort_crs_matrix(mm->C4)); 1136 #endif 1137 1138 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1139 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1140 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1141 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1142 1143 // Adt * (B's diag + B's off-diag) 1144 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1145 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1146 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1147 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1148 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1149 PetscCallCXX(sort_crs_matrix(mm->C1)); 1150 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1151 #endif 1152 1153 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1154 1155 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1156 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1157 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1158 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1159 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1160 1161 // C = (C1+Fd, C2+Fo) 1162 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1163 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1164 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1165 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1166 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1167 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1168 PetscFunctionReturn(PETSC_SUCCESS); 1169 } 1170 1171 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1172 { 1173 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1174 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1175 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1176 MPI_Comm comm; 1177 1178 PetscFunctionBegin; 1179 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1180 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1181 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1182 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1183 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1184 1185 // Aot * (B's diag + B's off-diag) 1186 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1187 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1188 1189 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1190 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1191 1192 // Adt * (B's diag + B's off-diag) 1193 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1194 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1195 1196 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1197 1198 // C = (C1+Fd, C2+Fo) 1199 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1200 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1201 PetscFunctionReturn(PETSC_SUCCESS); 1202 } 1203 1204 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1205 1206 Input Parameters: 1207 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1208 . A - an MPIAIJKOKKOS matrix 1209 . B - an MPIAIJKOKKOS matrix 1210 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1211 */ 1212 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1213 { 1214 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1215 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1216 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1217 1218 PetscFunctionBegin; 1219 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1220 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1221 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1222 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1223 1224 // TODO: add command line options to select spgemm algorithms 1225 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1226 1227 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1228 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1229 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1230 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1231 #endif 1232 #endif 1233 1234 mm->kh1.create_spgemm_handle(spgemm_alg); 1235 mm->kh2.create_spgemm_handle(spgemm_alg); 1236 mm->kh3.create_spgemm_handle(spgemm_alg); 1237 mm->kh4.create_spgemm_handle(spgemm_alg); 1238 1239 // Bcast B's rows to form F, and overlap the communication 1240 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1241 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1242 1243 // A's diag * (B's diag + B's off-diag) 1244 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1245 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1246 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1247 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1248 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1249 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1250 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1251 PetscCallCXX(sort_crs_matrix(mm->C1)); 1252 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1253 #endif 1254 1255 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1256 1257 // A's off-diag * (F's diag + F's off-diag) 1258 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1259 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1260 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1261 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1262 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1263 PetscCallCXX(sort_crs_matrix(mm->C3)); 1264 PetscCallCXX(sort_crs_matrix(mm->C4)); 1265 #endif 1266 1267 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1268 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1269 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1270 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1271 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1272 1273 // C = (Cd, Co) = (C1+C3, C2+C4) 1274 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1275 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1276 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1277 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1278 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1279 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1280 PetscFunctionReturn(PETSC_SUCCESS); 1281 } 1282 1283 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1284 { 1285 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1286 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1287 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1288 1289 PetscFunctionBegin; 1290 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1291 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1292 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1293 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1294 1295 // Bcast B's rows to form F, and overlap the communication 1296 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1297 1298 // A's diag * (B's diag + B's off-diag) 1299 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1300 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1301 1302 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1303 1304 // A's off-diag * (F's diag + F's off-diag) 1305 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1306 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1307 1308 // C = (Cd, Co) = (C1+C3, C2+C4) 1309 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1310 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1311 PetscFunctionReturn(PETSC_SUCCESS); 1312 } 1313 1314 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1315 { 1316 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1317 Mat_Product *product; 1318 MatProductData_MPIAIJKokkos *pdata; 1319 
MatProductType ptype; 1320 Mat A, B; 1321 1322 PetscFunctionBegin; 1323 MatCheckProduct(C, 1); // make sure C is a product 1324 product = C->product; 1325 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1326 ptype = product->type; 1327 A = product->A; 1328 B = product->B; 1329 1330 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1331 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1332 // we still do numeric. 1333 if (pdata->reusesym) { // numeric reuses results from symbolic 1334 pdata->reusesym = PETSC_FALSE; 1335 PetscFunctionReturn(PETSC_SUCCESS); 1336 } 1337 1338 if (ptype == MATPRODUCT_AB) { 1339 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1340 } else if (ptype == MATPRODUCT_AtB) { 1341 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1342 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1343 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1344 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1345 } 1346 1347 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1348 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1349 PetscFunctionReturn(PETSC_SUCCESS); 1350 } 1351 1352 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1353 { 1354 Mat A, B; 1355 Mat_Product *product; 1356 MatProductType ptype; 1357 MatProductData_MPIAIJKokkos *pdata; 1358 MatMatStruct *mm = NULL; 1359 PetscInt m, n, M, N; 1360 Mat Cd, Co; 1361 MPI_Comm comm; 1362 1363 PetscFunctionBegin; 1364 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1365 MatCheckProduct(C, 1); 1366 product = C->product; 1367 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1368 ptype = product->type; 1369 A = product->A; 1370 B = product->B; 
1371 1372 switch (ptype) { 1373 case MATPRODUCT_AB: 1374 m = A->rmap->n; 1375 n = B->cmap->n; 1376 M = A->rmap->N; 1377 N = B->cmap->N; 1378 break; 1379 case MATPRODUCT_AtB: 1380 m = A->cmap->n; 1381 n = B->cmap->n; 1382 M = A->cmap->N; 1383 N = B->cmap->N; 1384 break; 1385 case MATPRODUCT_PtAP: 1386 m = B->cmap->n; 1387 n = B->cmap->n; 1388 M = B->cmap->N; 1389 N = B->cmap->N; 1390 break; /* BtAB */ 1391 default: 1392 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1393 } 1394 1395 PetscCall(MatSetSizes(C, m, n, M, N)); 1396 PetscCall(PetscLayoutSetUp(C->rmap)); 1397 PetscCall(PetscLayoutSetUp(C->cmap)); 1398 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1399 1400 pdata = new MatProductData_MPIAIJKokkos(); 1401 pdata->reusesym = product->api_user; 1402 1403 if (ptype == MATPRODUCT_AB) { 1404 auto mmAB = new MatMatStruct_AB(); 1405 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1406 mm = pdata->mmAB = mmAB; 1407 } else if (ptype == MATPRODUCT_AtB) { 1408 auto mmAtB = new MatMatStruct_AtB(); 1409 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1410 mm = pdata->mmAtB = mmAtB; 1411 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1412 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1413 1414 auto mmAB = new MatMatStruct_AB(); 1415 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1416 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1417 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1418 pdata->mmAB = mmAB; 1419 1420 m = A->rmap->n; // Z's layout 1421 n = B->cmap->n; 1422 M = A->rmap->N; 1423 N = B->cmap->N; 1424 PetscCall(MatCreate(comm, &Z)); 1425 PetscCall(MatSetSizes(Z, m, n, M, N)); 1426 PetscCall(PetscLayoutSetUp(Z->rmap)); 1427 PetscCall(PetscLayoutSetUp(Z->cmap)); 1428 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1429 
PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1430 1431 auto mmAtB = new MatMatStruct_AtB(); 1432 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1433 1434 pdata->Z = Z; // give ownership to pdata 1435 mm = pdata->mmAtB = mmAtB; 1436 } 1437 1438 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1439 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1440 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1441 1442 C->product->data = pdata; 1443 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1444 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1445 PetscFunctionReturn(PETSC_SUCCESS); 1446 } 1447 1448 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1449 { 1450 Mat_Product *product = mat->product; 1451 PetscBool match = PETSC_FALSE; 1452 PetscBool usecpu = PETSC_FALSE; 1453 1454 PetscFunctionBegin; 1455 MatCheckProduct(mat, 1); 1456 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1457 if (match) { /* we can always fallback to the CPU if requested */ 1458 switch (product->type) { 1459 case MATPRODUCT_AB: 1460 if (product->api_user) { 1461 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1462 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1463 PetscOptionsEnd(); 1464 } else { 1465 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1466 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1467 PetscOptionsEnd(); 1468 } 1469 break; 1470 case MATPRODUCT_AtB: 1471 if 
(product->api_user) { 1472 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1473 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1474 PetscOptionsEnd(); 1475 } else { 1476 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1477 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1478 PetscOptionsEnd(); 1479 } 1480 break; 1481 case MATPRODUCT_PtAP: 1482 if (product->api_user) { 1483 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1484 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1485 PetscOptionsEnd(); 1486 } else { 1487 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1488 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1489 PetscOptionsEnd(); 1490 } 1491 break; 1492 default: 1493 break; 1494 } 1495 match = (PetscBool)!usecpu; 1496 } 1497 if (match) { 1498 switch (product->type) { 1499 case MATPRODUCT_AB: 1500 case MATPRODUCT_AtB: 1501 case MATPRODUCT_PtAP: 1502 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1503 break; 1504 default: 1505 break; 1506 } 1507 } 1508 /* fallback to MPIAIJ ops */ 1509 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1510 PetscFunctionReturn(PETSC_SUCCESS); 1511 } 1512 1513 // Mirror of MatCOOStruct_MPIAIJ on device 1514 struct MatCOOStruct_MPIAIJKokkos { 1515 PetscCount n; 1516 PetscSF sf; 1517 PetscCount Annz, Bnnz; 1518 PetscCount Annz2, Bnnz2; 1519 PetscCountKokkosView Ajmap1, Aperm1; 1520 PetscCountKokkosView Bjmap1, Bperm1; 1521 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1522 
PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1523 PetscCountKokkosView Cperm1; 1524 MatScalarKokkosView sendbuf, recvbuf; 1525 1526 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) 1527 { 1528 auto &exec = PetscGetKokkosExecutionSpace(); 1529 1530 n = coo_h->n; 1531 sf = coo_h->sf; 1532 Annz = coo_h->Annz; 1533 Bnnz = coo_h->Bnnz; 1534 Annz2 = coo_h->Annz2; 1535 Bnnz2 = coo_h->Bnnz2; 1536 Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1)); 1537 Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1)); 1538 Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1)); 1539 Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1)); 1540 Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2)); 1541 Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1)); 1542 Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2)); 1543 Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2)); 1544 Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1)); 1545 Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2)); 1546 Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen)); 1547 sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen)); 1548 recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)); 1549 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1550 } 1551 1552 
~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1553 }; 1554 1555 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void **data) 1556 { 1557 PetscFunctionBegin; 1558 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(*data)); 1559 PetscFunctionReturn(PETSC_SUCCESS); 1560 } 1561 1562 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1563 { 1564 PetscContainer container_h, container_d; 1565 MatCOOStruct_MPIAIJ *coo_h; 1566 MatCOOStruct_MPIAIJKokkos *coo_d; 1567 1568 PetscFunctionBegin; 1569 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1570 mat->preallocated = PETSC_TRUE; 1571 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1572 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1573 PetscCall(MatZeroEntries(mat)); 1574 1575 // Copy the COO struct to device 1576 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1577 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1578 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1579 1580 // Put the COO struct in a container and then attach that to the matrix 1581 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1582 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1583 PetscCall(PetscContainerSetCtxDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1584 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1585 PetscCall(PetscContainerDestroy(&container_d)); 1586 PetscFunctionReturn(PETSC_SUCCESS); 1587 } 1588 1589 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1590 { 1591 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1592 Mat A = mpiaij->A, B = mpiaij->B; 1593 MatScalarKokkosView Aa, Ba; 1594 MatScalarKokkosView v1; 
1595 PetscMemType memtype; 1596 PetscContainer container; 1597 MatCOOStruct_MPIAIJKokkos *coo; 1598 Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 1599 1600 PetscFunctionBegin; 1601 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1602 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1603 1604 const auto &n = coo->n; 1605 const auto &Annz = coo->Annz; 1606 const auto &Annz2 = coo->Annz2; 1607 const auto &Bnnz = coo->Bnnz; 1608 const auto &Bnnz2 = coo->Bnnz2; 1609 const auto &vsend = coo->sendbuf; 1610 const auto &v2 = coo->recvbuf; 1611 const auto &Ajmap1 = coo->Ajmap1; 1612 const auto &Ajmap2 = coo->Ajmap2; 1613 const auto &Aimap2 = coo->Aimap2; 1614 const auto &Bjmap1 = coo->Bjmap1; 1615 const auto &Bjmap2 = coo->Bjmap2; 1616 const auto &Bimap2 = coo->Bimap2; 1617 const auto &Aperm1 = coo->Aperm1; 1618 const auto &Aperm2 = coo->Aperm2; 1619 const auto &Bperm1 = coo->Bperm1; 1620 const auto &Bperm2 = coo->Bperm2; 1621 const auto &Cperm1 = coo->Cperm1; 1622 1623 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1624 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1625 v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n)); 1626 } else { 1627 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1628 } 1629 1630 if (imode == INSERT_VALUES) { 1631 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1632 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1633 } else { 1634 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1635 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1636 } 1637 1638 PetscCall(PetscLogGpuTimeBegin()); 1639 /* Pack entries to be sent to remote */ 1640 Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) 
= v1(Cperm1(i)); }); 1641 1642 /* Send remote entries to their owner and overlap the communication with local computation */ 1643 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1644 /* Add local entries to A and B in one kernel */ 1645 Kokkos::parallel_for( 1646 Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1647 PetscScalar sum = 0.0; 1648 if (i < Annz) { 1649 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1650 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1651 } else { 1652 i -= Annz; 1653 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1654 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1655 } 1656 }); 1657 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1658 1659 /* Add received remote entries to A and B in one kernel */ 1660 Kokkos::parallel_for( 1661 Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1662 if (i < Annz2) { 1663 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1664 } else { 1665 i -= Annz2; 1666 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1667 } 1668 }); 1669 PetscCall(PetscLogGpuTimeEnd()); 1670 1671 if (imode == INSERT_VALUES) { 1672 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1673 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1674 } else { 1675 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1676 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1677 } 1678 PetscFunctionReturn(PETSC_SUCCESS); 1679 } 1680 1681 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1682 { 1683 PetscFunctionBegin; 1684 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1685 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1686 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1687 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1688 #if defined(PETSC_HAVE_HYPRE) 1689 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_mpiaijkokkos_hypre_C", NULL)); 1690 #endif 1691 PetscCall(MatDestroy_MPIAIJ(A)); 1692 PetscFunctionReturn(PETSC_SUCCESS); 1693 } 1694 1695 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1696 { 1697 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1698 PetscBool congruent; 1699 1700 PetscFunctionBegin; 1701 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1702 if (congruent) { // square matrix and the diagonals are solely in the diag block 1703 PetscCall(MatShift(mpiaij->A, a)); 1704 } else { // too hard, use the general version 1705 PetscCall(MatShift_Basic(A, a)); 1706 } 1707 PetscFunctionReturn(PETSC_SUCCESS); 1708 } 1709 1710 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1711 { 1712 PetscFunctionBegin; 1713 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1714 B->ops->mult = MatMult_MPIAIJKokkos; 1715 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1716 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1717 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1718 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1719 B->ops->shift = MatShift_MPIAIJKokkos; 1720 1721 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1722 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1723 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1724 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1725 #if defined(PETSC_HAVE_HYPRE) 1726 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_mpiaijkokkos_hypre_C", MatConvert_AIJ_HYPRE)); 1727 #endif 1728 PetscFunctionReturn(PETSC_SUCCESS); 1729 } 1730 1731 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1732 { 1733 Mat B; 1734 Mat_MPIAIJ *a; 1735 1736 PetscFunctionBegin; 1737 if (reuse == MAT_INITIAL_MATRIX) { 1738 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1739 } else if (reuse == MAT_REUSE_MATRIX) { 1740 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1741 } 1742 B = *newmat; 1743 1744 B->boundtocpu = PETSC_FALSE; 1745 PetscCall(PetscFree(B->defaultvectype)); 1746 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1747 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1748 1749 a = static_cast<Mat_MPIAIJ *>(A->data); 1750 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1751 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1752 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1753 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1754 PetscFunctionReturn(PETSC_SUCCESS); 1755 } 1756 1757 /*MC 1758 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1759 1760 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1761 1762 Options Database Key: 1763 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1764 1765 Level: beginner 1766 1767 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1768 M*/ 1769 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1770 { 1771 PetscFunctionBegin; 1772 PetscCall(PetscKokkosInitializeCheck()); 1773 PetscCall(MatCreate_MPIAIJ(A)); 1774 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1775 PetscFunctionReturn(PETSC_SUCCESS); 1776 } 1777 1778 /*@C 1779 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format 1780 (the default parallel PETSc format). This matrix will ultimately pushed down 1781 to Kokkos for calculations. 1782 1783 Collective 1784 1785 Input Parameters: 1786 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1787 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1788 This value should be the same as the local size used in creating the 1789 y vector for the matrix-vector product y = Ax. 1790 . n - This value should be the same as the local size used in creating the 1791 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1792 calculated if N is given) For square matrices n is almost always `m`. 1793 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1794 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1795 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1796 (same value is used for all local rows) 1797 . d_nnz - array containing the number of nonzeros in the various rows of the 1798 DIAGONAL portion of the local submatrix (possibly different for each row) 1799 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1800 The size of this array is equal to the number of local rows, i.e `m`. 
1801 For matrices you plan to factor you must leave room for the diagonal entry and 1802 put in the entry even if it is zero. 1803 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1804 submatrix (same value is used for all local rows). 1805 - o_nnz - array containing the number of nonzeros in the various rows of the 1806 OFF-DIAGONAL portion of the local submatrix (possibly different for 1807 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1808 structure. The size of this array is equal to the number 1809 of local rows, i.e `m`. 1810 1811 Output Parameter: 1812 . A - the matrix 1813 1814 Level: intermediate 1815 1816 Notes: 1817 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1818 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1819 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1820 1821 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1822 storage. That is, the stored row and column indices can begin at 1823 either one (as in Fortran) or zero. 
1824 1825 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1826 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()` 1827 @*/ 1828 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1829 { 1830 PetscMPIInt size; 1831 1832 PetscFunctionBegin; 1833 PetscCall(MatCreate(comm, A)); 1834 PetscCall(MatSetSizes(*A, m, n, M, N)); 1835 PetscCallMPI(MPI_Comm_size(comm, &size)); 1836 if (size > 1) { 1837 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1838 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1839 } else { 1840 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1841 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1842 } 1843 PetscFunctionReturn(PETSC_SUCCESS); 1844 } 1845