1 #include <petsc_kokkos.hpp> 2 #include <petscvec_kokkos.hpp> 3 #include <petscpkg_version.h> 4 #include <petsc/private/sfimpl.h> 5 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 6 #include <../src/mat/impls/aij/mpi/mpiaij.h> 7 #include <KokkosSparse_spadd.hpp> 8 #include <KokkosSparse_spgemm.hpp> 9 10 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 11 { 12 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 13 14 PetscFunctionBegin; 15 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 16 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 17 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 18 */ 19 if (mode == MAT_FINAL_ASSEMBLY) { 20 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 21 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 22 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 23 } 24 PetscFunctionReturn(PETSC_SUCCESS); 25 } 26 27 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 28 { 29 Mat_MPIAIJ *mpiaij; 30 31 PetscFunctionBegin; 32 // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things 33 PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz)); 34 mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 35 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A)); 36 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B)); 37 PetscFunctionReturn(PETSC_SUCCESS); 38 } 39 40 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 41 { 42 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 43 PetscInt nt; 44 45 PetscFunctionBegin; 46 PetscCall(VecGetLocalSize(xx, &nt)); 47 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and 
xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 48 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 49 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 50 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 51 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 52 PetscFunctionReturn(PETSC_SUCCESS); 53 } 54 55 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 56 { 57 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 58 PetscInt nt; 59 60 PetscFunctionBegin; 61 PetscCall(VecGetLocalSize(xx, &nt)); 62 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 63 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 64 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 65 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 66 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 67 PetscFunctionReturn(PETSC_SUCCESS); 68 } 69 70 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 71 { 72 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 73 PetscInt nt; 74 75 PetscFunctionBegin; 76 PetscCall(VecGetLocalSize(xx, &nt)); 77 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 78 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 79 PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 80 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 81 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 82 PetscFunctionReturn(PETSC_SUCCESS); 83 } 84 85 /* Merge the "A, B" matrices of mat into a matrix C. 
mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 86 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 87 C still uses local column ids. Their corresponding global column ids are returned in glob. 88 */ 89 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 90 { 91 Mat Ad, Ao; 92 const PetscInt *cmap; 93 94 PetscFunctionBegin; 95 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 96 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 97 if (glob) { 98 PetscInt cst, i, dn, on, *gidx; 99 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 100 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 101 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 102 PetscCall(PetscMalloc1(dn + on, &gidx)); 103 for (i = 0; i < dn; i++) gidx[i] = cst + i; 104 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 105 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 106 } 107 PetscFunctionReturn(PETSC_SUCCESS); 108 } 109 110 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 111 struct MatMatStruct { 112 PetscInt n, *garray; // C's garray and its size. 
113 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 114 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 115 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 116 PetscIntKokkosView E_NzLeft; 117 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 118 MatScalarKokkosView rootBuf, leafBuf; 119 KokkosCsrMatrix Fd, Fo; // F in split form 120 121 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 122 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 123 KernelHandle kh3; // compute C3 124 KernelHandle kh4; // compute C4 125 126 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 127 PetscInt E_VectorLength; 128 PetscInt E_RowsPerTeam; 129 PetscInt F_TeamSize; 130 PetscInt F_VectorLength; 131 PetscInt F_RowsPerTeam; 132 133 ~MatMatStruct() 134 { 135 PetscFunctionBegin; 136 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 137 PetscFunctionReturnVoid(); 138 } 139 }; 140 141 struct MatMatStruct_AB : public MatMatStruct { 142 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 143 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 144 PetscIntKokkosView rowoffset; 145 }; 146 147 struct MatMatStruct_AtB : public MatMatStruct { 148 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 149 MatColIdxKokkosView Fdjperm; 150 MatColIdxKokkosView Fojmap; 151 MatColIdxKokkosView Fojperm; 152 }; 153 154 struct MatProductData_MPIAIJKokkos { 155 MatMatStruct_AB *mmAB = nullptr; 156 MatMatStruct_AtB *mmAtB = nullptr; 157 PetscBool reusesym = PETSC_FALSE; 158 Mat Z = nullptr; // store Z=AB in computing BtAB 159 160 ~MatProductData_MPIAIJKokkos() 161 { 162 delete mmAB; 163 delete mmAtB; 164 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 165 } 166 }; 167 168 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 169 { 170 PetscFunctionBegin; 171 
PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 172 PetscFunctionReturn(PETSC_SUCCESS); 173 } 174 175 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 176 It is similar to MatCreateMPIAIJWithSplitArrays. 177 178 Input Parameters: 179 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 180 . A - the diag matrix using local col ids 181 - B - the offdiag matrix using global col ids 182 183 Output Parameter: 184 . mat - the updated MATMPIAIJKOKKOS matrix 185 */ 186 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 187 { 188 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 189 PetscInt m, n, M, N, Am, An, Bm, Bn; 190 191 PetscFunctionBegin; 192 PetscCall(MatGetSize(mat, &M, &N)); 193 PetscCall(MatGetLocalSize(mat, &m, &n)); 194 PetscCall(MatGetLocalSize(A, &Am, &An)); 195 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 196 197 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 198 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 199 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 200 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 201 mpiaij->A = A; 202 mpiaij->B = B; 203 mpiaij->garray = garray; 204 205 mat->preallocated = PETSC_TRUE; 206 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 207 208 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 209 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 210 /* MatAssemblyEnd is critical here. 
It sets mat->offloadmask according to A and B's, and 211 also gets mpiaij->B compacted, with its col ids and size reduced 212 */ 213 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 214 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 215 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 216 PetscFunctionReturn(PETSC_SUCCESS); 217 } 218 219 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 220 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 221 template <class ExecutionSpace> 222 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 223 { 224 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 225 226 PetscFunctionBegin; 227 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 228 229 if (nnz_per_row < 1) nnz_per_row = 1; 230 231 int max_vector_length = teamPolicy.vector_length_max(); 232 233 if (vector_length < 1) { 234 vector_length = 1; 235 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 236 } 237 238 // Determine rows per thread 239 if (rows_per_thread < 1) { 240 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 241 else { 242 if (nnz_per_row < 20 && nnz > 5000000) { 243 rows_per_thread = 256; 244 } else rows_per_thread = 64; 245 } 246 } 247 248 if (team_size < 1) { 249 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 250 team_size = 256 / vector_length; 251 } else { 252 team_size = 1; 253 } 254 } 255 256 rows_per_team = rows_per_thread * team_size; 257 258 if (rows_per_team < 0) { 259 PetscInt nnz_per_team = 4096; 260 PetscInt conc = ExecutionSpace().concurrency(); 261 while ((conc * nnz_per_team * 
4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 262 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 263 } 264 PetscFunctionReturn(PETSC_SUCCESS); 265 } 266 267 /* 268 Reduce two sets of global indices into local ones 269 270 Input Parameters: 271 + n1 - size of garray1[], the first set 272 . garray1[n1] - a sorted global index array (without duplicates) 273 . m - size of indices[], the second set 274 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 275 276 Output Parameters: 277 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 278 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 279 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 280 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 281 282 Example, say 283 n1 = 5 284 garray1[5] = {1, 4, 7, 8, 10} 285 m = 4 286 indices[4] = {2, 4, 8, 9} 287 288 Combining them together, we have 7 global indices in garray2[] 289 n2 = 7 290 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 291 292 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 293 map[5] = {0, 2, 3, 4, 6} 294 295 On output, indices[] is updated with local indices 296 indices[4] = {1, 2, 4, 5} 297 */ 298 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 299 { 300 PetscHMapI g2l = nullptr; 301 PetscHashIter iter; 302 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 303 PetscInt n2, *garray2; 304 305 PetscFunctionBegin; 306 tot = 0; 307 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 308 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 309 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 310 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 311 } 312 313 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 314 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 315 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 316 } 317 318 // Pull out (unique) globals in the hash table and put them in garray2[] 319 n2 = tot; 320 PetscCall(PetscMalloc1(n2, &garray2)); 321 tot = 0; 322 PetscHashIterBegin(g2l, iter); 323 while (!PetscHashIterAtEnd(g2l, iter)) { 324 PetscHashIterGetKey(g2l, iter, key); 325 PetscHashIterNext(g2l, iter); 326 garray2[tot++] = key; 327 } 328 329 // Sort garray2[] and then map them to local indices starting from 0 330 PetscCall(PetscSortInt(n2, garray2)); 331 PetscCall(PetscHMapIClear(g2l)); 332 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 333 334 // Rewrite indices[] with local indices 335 for (PetscInt i = 0; i < m; i++) { 336 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 337 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 338 indices[i] = val; 339 } 340 // Record the map that maps garray1[i] to garray2[map[i]] 341 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 342 PetscCall(PetscHMapIDestroy(&g2l)); 343 *n2_ = n2; 344 *garray2_ = garray2; 345 PetscFunctionReturn(PETSC_SUCCESS); 346 } 347 348 /* 349 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 350 
  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.

  Think of each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
  In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.

  Input Parameters:
+  comm        - MPI communicator of E
.  A           - diag block of E, using local column indices
.  B           - off-diag block of E, using local column indices
.  cstart      - (global) start column of Ed
.  cend        - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
.  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
.  ownerSF     - the SF that specifies ownership (root) of rows in E
.  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
-  mm          - to stash intermediate data structures for reuse

  Output Parameters:
+  map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
-  mm      - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
373 374 */ 375 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 376 { 377 PetscFunctionBegin; 378 if (reuse == MAT_INITIAL_MATRIX) { 379 PetscInt Em = A.numRows(), Fm; 380 PetscInt n1 = B.numCols(); 381 382 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 383 384 // Do the analysis on host 385 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 386 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 387 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 388 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 389 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 390 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 391 392 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 393 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 394 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 395 for (PetscInt i = 0; i < Em; i++) { 396 const PetscInt *first, *last, *it; 397 PetscInt count, step; 398 // std::lower_bound(first,last,cstart), but need to use global column indices 399 first = Bj + Bi[i]; 400 last = Bj + Bi[i + 1]; 401 count = last - first; 402 while (count > 0) { 403 it = first; 404 step = count / 2; 405 it += step; 406 if (garray1[*it] < cstart) { // map local to global 407 first = ++it; 408 count -= step + 1; 409 } else count = step; 410 } 411 E_NzLeft[i] = first - (Bj + Bi[i]); 412 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 413 } 414 415 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 416 const PetscMPIInt *iranks, *ranks; 417 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 418 PetscInt niranks, nranks; 419 MPI_Request *reqs; 420 PetscMPIInt tag; 421 PetscSF reduceSF; 422 PetscInt *sdisp, *rdisp; 423 424 PetscCall(PetscCommGetNewTag(comm, &tag)); 425 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 426 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 427 428 // Find out length of each row I will receive. Even for the same row index, when they are from 429 // different senders, they might have different lengths (and sparsity patterns) 430 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 431 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 432 433 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 434 435 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 436 recvRowLen[0] = 0; // since we will make it in CSR format later 437 recvRowLen++; // advance the pointer now 438 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 439 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 440 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 441 442 // Build the real PetscSF for reducing E rows (buffer to buffer) 443 rdisp[0] = 0; 444 for (PetscInt i = 0; i < niranks; i++) { 445 rdisp[i + 1] = rdisp[i]; 446 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 447 } 448 recvRowLen--; // put it back into csr format 449 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 
1] += recvRowLen[i]; 450 451 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 452 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 453 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 454 455 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 456 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 457 PetscSFNode *iremote; 458 459 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 460 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 461 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 462 463 for (PetscInt i = 0; i < nranks; i++) { 464 PetscInt count = 0; 465 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 466 for (PetscInt j = 0; j < count; j++) { 467 iremote[nleaves + j].rank = ranks[i]; 468 iremote[nleaves + j].index = sdisp[i] + j; 469 } 470 nleaves += count; 471 } 472 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 473 474 PetscCall(PetscSFCreate(comm, &reduceSF)); 475 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 476 477 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 478 PetscInt *sendCol, *recvCol; 479 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 480 for (PetscInt k = 0; k < roffset[nranks]; k++) { 481 PetscInt i = rmine[k]; // row to be copied 482 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 483 PetscInt nzLeft = E_NzLeft[i]; 484 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 485 for (PetscInt j = 0; j < alen + blen; j++) { 486 if (j < nzLeft) { 487 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 488 } else if (j < nzLeft + alen) { 489 buf[j] = Aj[Ai[i] 
+ j - nzLeft] + cstart; // diag A, also in global 490 } else { 491 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 492 } 493 } 494 } 495 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 496 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 497 498 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 499 PetscInt *recvRowPerm, *recvColSorted; 500 PetscInt *recvNzPerm, *recvNzPermSorted; 501 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 502 503 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 504 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 505 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 506 507 // i[] array, nz are always easiest to compute 508 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 509 MatRowMapType *Fdi, *Foi; 510 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 511 PetscInt iter; 512 513 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 514 Kokkos::deep_copy(Foi_h, 0); 515 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 516 Foi = Foi_h.data() + 1; 517 iter = 0; 518 while (iter < recvRowCnt) { // iter over received rows 519 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 520 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 521 522 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 523 524 // Copy column indices 
(and their permutation) of these rows into recvColSorted & recvNzPermSorted 525 PetscInt nz = 0; // nz (with dups) in the current row 526 PetscInt *jbuf = recvColSorted + FnzDups; 527 PetscInt *pbuf = recvNzPermSorted + FnzDups; 528 PetscInt *jbuf2 = jbuf; // temp pointers 529 PetscInt *pbuf2 = pbuf; 530 for (PetscInt d = 0; d < dupRows; d++) { 531 PetscInt i = recvRowPerm[iter + d]; 532 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 533 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 534 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 535 jbuf2 += len; 536 pbuf2 += len; 537 nz += len; 538 } 539 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 540 541 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 542 PetscInt cur = 0; 543 while (cur < nz) { 544 PetscInt curColIdx = jbuf[cur]; 545 PetscInt dups = 1; 546 547 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 548 if (curColIdx >= cstart && curColIdx < cend) { 549 Fdi[curRowIdx]++; 550 FdnzDups += dups; 551 } else { 552 Foi[curRowIdx]++; 553 FonzDups += dups; 554 } 555 cur += dups; 556 } 557 558 FnzDups += nz; 559 iter += dupRows; // Move to next unique row 560 } 561 562 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 563 Foi = Foi_h.data(); 564 for (PetscInt i = 0; i < Fm; i++) { 565 Fdi[i + 1] += Fdi[i]; 566 Foi[i + 1] += Foi[i]; 567 } 568 Fdnz = Fdi[Fm]; 569 Fonz = Foi[Fm]; 570 PetscCall(PetscFree2(sendCol, recvCol)); 571 572 // Allocate j, jmap, jperm for Fd and Fo 573 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 574 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 575 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 576 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 577 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 578 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 579 580 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 581 Fdjmap[0] = 0; 582 Fojmap[0] = 0; 583 FnzDups = 0; 584 Fdnz = 0; 585 Fonz = 0; 586 iter = 0; // iter over received rows 587 while (iter < recvRowCnt) { 588 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 589 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 590 PetscInt nz = 0; // nz (with dups) in the current row 591 592 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 593 for (PetscInt d = 0; d < dupRows; d++) { 594 PetscInt i = recvRowPerm[iter + d]; 595 nz += recvRowLen[i + 1] - recvRowLen[i]; 596 } 597 598 PetscInt *jbuf = recvColSorted + FnzDups; 599 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 600 PetscInt cur = 0; 601 while (cur < nz) { 602 PetscInt curColIdx = jbuf[cur]; 603 PetscInt dups = 1; 604 605 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 606 if (curColIdx >= cstart && curColIdx < cend) { 607 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 608 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 609 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 610 FdnzDups += dups; 611 Fdnz++; 612 } else { 613 Foj[Fonz] = curColIdx; // in global 614 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 615 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 616 FonzDups += dups; 617 Fonz++; 618 } 619 cur += dups; 620 FnzDups += dups; 621 } 622 iter += dupRows; // Move to next unique row 623 } 624 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 625 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 626 627 // Combine global column indices in garray1[] and Foj[] 628 PetscInt n2, *garray2; 629 630 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 631 mm->sf = reduceSF; 632 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 633 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 634 mm->garray = garray2; // give ownership, so no free 635 mm->n = n2; 636 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 637 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 638 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 639 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 640 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 641 642 // Output Fd and Fo in KokkosCsrMatrix format 643 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 644 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 645 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 646 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 647 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 648 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 649 650 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 651 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 652 653 // Compute kernel launch parameters in merging E 654 PetscInt teamSize, vectorLength, rowsPerTeam; 655 656 teamSize = vectorLength = rowsPerTeam = -1; 657 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 658 mm->E_TeamSize = teamSize; 659 mm->E_VectorLength = 
vectorLength; 660 mm->E_RowsPerTeam = rowsPerTeam; 661 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 662 663 // Handy aliases 664 auto &Aa = A.values; 665 auto &Ba = B.values; 666 const auto &Ai = A.graph.row_map; 667 const auto &Bi = B.graph.row_map; 668 const auto &E_NzLeft = mm->E_NzLeft; 669 auto &leafBuf = mm->leafBuf; 670 auto &rootBuf = mm->rootBuf; 671 PetscSF reduceSF = mm->sf; 672 PetscInt Em = A.numRows(); 673 PetscInt teamSize = mm->E_TeamSize; 674 PetscInt vectorLength = mm->E_VectorLength; 675 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 676 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 677 678 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 679 PetscCallCXX(Kokkos::parallel_for( 680 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 681 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 682 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 683 if (i < Em) { 684 PetscInt disp = Ai(i) + Bi(i); 685 PetscInt alen = Ai(i + 1) - Ai(i); 686 PetscInt blen = Bi(i + 1) - Bi(i); 687 PetscInt nzleft = E_NzLeft(i); 688 689 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 690 MatScalar &val = leafBuf(disp + j); 691 if (j < nzleft) { // B left 692 val = Ba(Bi(i) + j); 693 } else if (j < nzleft + alen) { // diag A 694 val = Aa(Ai(i) + j - nzleft); 695 } else { // B right 696 val = Ba(Bi(i) + j - alen); 697 } 698 }); 699 } 700 }); 701 })); 702 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 703 PetscFunctionReturn(PETSC_SUCCESS); 704 } 705 706 // To finish MatMPIAIJKokkosReduce. 
707 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 708 { 709 auto &leafBuf = mm->leafBuf; 710 auto &rootBuf = mm->rootBuf; 711 auto &Fda = mm->Fd.values; 712 const auto &Fdjmap = mm->Fdjmap; 713 const auto &Fdjperm = mm->Fdjperm; 714 auto Fdnz = mm->Fd.nnz(); 715 auto &Foa = mm->Fo.values; 716 const auto &Fojmap = mm->Fojmap; 717 const auto &Fojperm = mm->Fojperm; 718 auto Fonz = mm->Fo.nnz(); 719 PetscSF reduceSF = mm->sf; 720 721 PetscFunctionBegin; 722 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 723 724 // Reduce data in rootBuf to Fd and Fo 725 PetscCallCXX(Kokkos::parallel_for( 726 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 727 PetscScalar sum = 0.0; 728 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 729 Fda(i) = sum; 730 })); 731 732 PetscCallCXX(Kokkos::parallel_for( 733 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 734 PetscScalar sum = 0.0; 735 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 736 Foa(i) = sum; 737 })); 738 PetscFunctionReturn(PETSC_SUCCESS); 739 } 740 741 /* 742 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 743 744 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 745 device and involves various index mapping. 746 747 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 
  Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k
  to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
  F has the same column layout as E.

  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
  Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
  Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
  column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
  column indices in Fo and update Fo with local indices.

  Input Parameters:
+ E - the MPIAIJKOKKOS matrix
. ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
. reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm - to stash matproduct intermediate data structures

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
- mm - contains various info, such as garray2[], Fd, Fo, etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
  The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide opportunities to overlap computation with communication.
771 */ 772 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 773 { 774 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 775 Mat A = empi->A, B = empi->B; // diag and off-diag 776 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 777 PetscInt Em = E->rmap->n; // #local rows 778 MPI_Comm comm; 779 780 PetscFunctionBegin; 781 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 782 if (reuse == MAT_INITIAL_MATRIX) { 783 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 784 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 785 const PetscInt *garray1 = empi->garray; // its size is n1 786 PetscInt cstart, cend; 787 PetscSF bcastSF; 788 789 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 790 791 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 792 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 793 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 794 for (PetscInt i = 0; i < Em; i++) { 795 const PetscInt *first, *last, *it; 796 PetscInt count, step; 797 // std::lower_bound(first,last,cstart), but need to use global column indices 798 first = Bj + Bi[i]; 799 last = Bj + Bi[i + 1]; 800 count = last - first; 801 while (count > 0) { 802 it = first; 803 step = count / 2; 804 it += step; 805 if (empi->garray[*it] < cstart) { // map local to global 806 first = ++it; 807 count -= step + 1; 808 } else count = step; 809 } 810 E_NzLeft[i] = first - (Bj + Bi[i]); 811 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 812 } 813 814 // Compute row pointer Fi of F 815 PetscInt *Fi, Fm, Fnz; 816 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 817 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 818 Fi[0] = 0; 819 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 820 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 821 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 822 Fnz = Fi[Fm]; 823 824 // Build the real PetscSF for bcasting E rows (buffer to buffer) 825 const PetscMPIInt *iranks, *ranks; 826 const PetscInt *ioffset, *irootloc, *roffset; 827 PetscInt niranks, nranks, *sdisp, *rdisp; 828 MPI_Request *reqs; 829 PetscMPIInt tag; 830 831 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 832 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 833 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 834 835 sdisp[0] = 0; // send displacement 836 for (PetscInt i = 0; i < niranks; i++) { 837 sdisp[i + 1] = sdisp[i]; 838 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 839 PetscInt r = irootloc[j]; // row to be sent 840 sdisp[i + 1] += E_RowLen[r]; 841 } 842 } 843 844 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 845 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 846 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 847 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 848 849 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 850 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 851 PetscSFNode *iremote; // give ownership to bcastSF 852 PetscCall(PetscMalloc1(nleaves, &iremote)); 853 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 854 PetscInt k = 0; 855 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 856 iremote[j].rank = ranks[i]; 857 iremote[j].index = rdisp[i] + k; // their root location 858 k++; 859 } 860 } 861 PetscCall(PetscSFCreate(comm, &bcastSF)); 862 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 863 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 864 865 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 866 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 867 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 868 rowoffset[0] = 0; 869 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 870 871 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 872 PetscInt *jbuf, *Fj; 873 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 874 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 875 PetscInt i = irootloc[k]; // row to be copied 876 PetscInt *buf = &jbuf[rowoffset[k]]; 877 PetscInt nzLeft = E_NzLeft[i]; 878 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 879 for (PetscInt j = 0; j < alen + blen; j++) { 880 if (j < nzLeft) { 881 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 882 } else if (j < nzLeft + alen) { 883 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 884 } else { 885 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 886 } 887 } 888 } 889 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 890 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 891 892 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 893 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 894 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 895 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 896 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 897 898 Fdi[0] = Foi[0] = 0; 899 for (PetscInt i = 0; i < Fm; i++) { 900 PetscInt *first, *last, *lb1, *lb2; 901 // cut the row into: Left, [cstart, cend), Right 902 first = Fj + Fi[i]; 903 last = Fj + Fi[i + 1]; 904 lb1 = std::lower_bound(first, last, cstart); 905 F_NzLeft[i] = lb1 - first; 906 lb2 = std::lower_bound(first, last, cend); 907 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 908 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 909 } 910 for (PetscInt i = 0; i < Fm; i++) { 911 Fdi[i + 1] += Fdi[i]; 912 Foi[i + 1] += Foi[i]; 913 } 914 915 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 916 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 917 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 918 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 919 920 for (PetscInt i = 0; i < Fm; i++) { 921 PetscInt nzLeft = F_NzLeft[i]; 922 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 923 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 924 gid = Fj[Fi[i] + j]; 925 if (j < nzLeft) { // left, in global 926 Foj[Foi[i] + j] = gid; 927 } else if (j < nzLeft + len) { // diag, in local 928 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 929 } else { // right, in global 930 Foj[Foi[i] + j - len] = gid; 931 } 932 } 933 } 934 PetscCall(PetscFree2(jbuf, Fj)); 935 PetscCall(PetscFree(Fi)); 936 937 // Reduce global indices in Foj[] and garray1[] into local ones 938 PetscInt n2, *garray2; 939 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 940 941 // Record the plans built above, for reuse 942 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 943 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 944 Kokkos::deep_copy(irootloc_h, tmp); 945 mm->sf = bcastSF; 946 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 947 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 948 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 949 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 950 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 951 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 952 mm->garray = garray2; 953 mm->n = n2; 954 955 // Output Fd and Fo in KokkosCsrMatrix format 956 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 957 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 958 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 959 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 960 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 961 962 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 963 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 964 965 // Compute kernel launch parameters in merging E or splitting F 966 PetscInt teamSize, vectorLength, rowsPerTeam; 967 968 teamSize = vectorLength = rowsPerTeam = -1; 969 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 970 mm->E_TeamSize = teamSize; 971 mm->E_VectorLength = vectorLength; 972 mm->E_RowsPerTeam = rowsPerTeam; 973 974 teamSize = vectorLength = rowsPerTeam = -1; 975 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, 
vectorLength, rowsPerTeam)); 976 mm->F_TeamSize = teamSize; 977 mm->F_VectorLength = vectorLength; 978 mm->F_RowsPerTeam = rowsPerTeam; 979 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 980 981 // Sync E's value to device 982 akok->a_dual.sync_device(); 983 bkok->a_dual.sync_device(); 984 985 // Handy aliases 986 const auto &Aa = akok->a_dual.view_device(); 987 const auto &Ba = bkok->a_dual.view_device(); 988 const auto &Ai = akok->i_dual.view_device(); 989 const auto &Bi = bkok->i_dual.view_device(); 990 991 // Fetch the plans 992 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 993 PetscSF &bcastSF = mm->sf; 994 MatScalarKokkosView &rootBuf = mm->rootBuf; 995 MatScalarKokkosView &leafBuf = mm->leafBuf; 996 PetscIntKokkosView &irootloc = mm->irootloc; 997 PetscIntKokkosView &rowoffset = mm->rowoffset; 998 999 PetscInt teamSize = mm->E_TeamSize; 1000 PetscInt vectorLength = mm->E_VectorLength; 1001 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1002 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1003 1004 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1005 PetscCallCXX(Kokkos::parallel_for( 1006 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1007 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1008 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1009 if (r < irootloc.extent(0)) { 1010 PetscInt i = irootloc(r); // row i of E 1011 PetscInt disp = rowoffset(r); 1012 PetscInt alen = Ai(i + 1) - Ai(i); 1013 PetscInt blen = Bi(i + 1) - Bi(i); 1014 PetscInt nzleft = E_NzLeft(i); 1015 1016 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1017 if (j < nzleft) { // B left 1018 rootBuf(disp + j) = Ba(Bi(i) + j); 1019 } else if (j < nzleft + alen) { // diag A 1020 rootBuf(disp + j) = Aa(Ai(i) + j - 
nzleft); 1021 } else { // B right 1022 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1023 } 1024 }); 1025 } 1026 }); 1027 })); 1028 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1029 PetscFunctionReturn(PETSC_SUCCESS); 1030 } 1031 1032 // To finish MatMPIAIJKokkosBcast. 1033 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1034 { 1035 PetscFunctionBegin; 1036 const auto &Fd = mm->Fd; 1037 const auto &Fo = mm->Fo; 1038 const auto &Fdi = Fd.graph.row_map; 1039 const auto &Foi = Fo.graph.row_map; 1040 auto &Fda = Fd.values; 1041 auto &Foa = Fo.values; 1042 auto Fm = Fd.numRows(); 1043 1044 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1045 PetscSF &bcastSF = mm->sf; 1046 MatScalarKokkosView &rootBuf = mm->rootBuf; 1047 MatScalarKokkosView &leafBuf = mm->leafBuf; 1048 PetscInt teamSize = mm->F_TeamSize; 1049 PetscInt vectorLength = mm->F_VectorLength; 1050 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1051 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1052 1053 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1054 1055 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1056 PetscCallCXX(Kokkos::parallel_for( 1057 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1058 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1059 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1060 if (i < Fm) { 1061 PetscInt nzLeft = F_NzLeft(i); 1062 PetscInt alen = Fdi(i + 1) - Fdi(i); 1063 PetscInt blen = Foi(i + 1) - Foi(i); 1064 PetscInt Fii = Fdi(i) + Foi(i); 1065 1066 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1067 PetscScalar val = leafBuf(Fii + j); 1068 if (j < nzLeft) { // 
left 1069 Foa(Foi(i) + j) = val; 1070 } else if (j < nzLeft + alen) { // diag 1071 Fda(Fdi(i) + j - nzLeft) = val; 1072 } else { // right 1073 Foa(Foi(i) + j - alen) = val; 1074 } 1075 }); 1076 } 1077 }); 1078 })); 1079 PetscFunctionReturn(PETSC_SUCCESS); 1080 } 1081 1082 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1083 { 1084 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1085 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1086 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1087 PetscInt cstart, cend; 1088 MPI_Comm comm; 1089 1090 PetscFunctionBegin; 1091 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1092 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1093 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1094 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1095 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1096 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1097 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1098 1099 // TODO: add command line options to select spgemm algorithms 1100 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1101 1102 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1103 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1104 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1105 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1106 #endif 1107 #endif 1108 1109 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1110 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1111 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1112 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1113 1114 // Aot * (B's diag + B's off-diag) 1115 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1116 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1117 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1118 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1119 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1120 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1121 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1122 1123 PetscCallCXX(sort_crs_matrix(mm->C3)); 1124 PetscCallCXX(sort_crs_matrix(mm->C4)); 1125 #endif 1126 1127 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1128 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1129 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1130 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1131 1132 // Adt * (B's diag + B's off-diag) 1133 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1134 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1135 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1136 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1137 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1138 PetscCallCXX(sort_crs_matrix(mm->C1)); 1139 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1140 #endif 1141 1142 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1143 1144 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1145 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1146 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1147 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1148 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1149 1150 // C = (C1+Fd, C2+Fo) 1151 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1152 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1153 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1154 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1155 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1156 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1157 PetscFunctionReturn(PETSC_SUCCESS); 1158 } 1159 1160 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1161 { 1162 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1163 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1164 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1165 MPI_Comm comm; 1166 1167 PetscFunctionBegin; 1168 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1169 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1170 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1171 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1172 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1173 1174 // Aot * (B's diag + B's off-diag) 1175 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1176 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1177 1178 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1179 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1180 1181 // Adt * (B's diag + B's off-diag) 1182 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1183 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1184 1185 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1186 1187 // C = (C1+Fd, C2+Fo) 1188 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1189 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1190 PetscFunctionReturn(PETSC_SUCCESS); 1191 } 1192 1193 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1194 1195 Input Parameters: 1196 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1197 . A - an MPIAIJKOKKOS matrix 1198 . B - an MPIAIJKOKKOS matrix 1199 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1200 */ 1201 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1202 { 1203 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1204 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1205 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1206 1207 PetscFunctionBegin; 1208 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1209 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1210 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1211 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1212 1213 // TODO: add command line options to select spgemm algorithms 1214 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1215 1216 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1217 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1218 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1219 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1220 #endif 1221 #endif 1222 1223 mm->kh1.create_spgemm_handle(spgemm_alg); 1224 mm->kh2.create_spgemm_handle(spgemm_alg); 1225 mm->kh3.create_spgemm_handle(spgemm_alg); 1226 mm->kh4.create_spgemm_handle(spgemm_alg); 1227 1228 // Bcast B's rows to form F, and overlap the communication 1229 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1230 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1231 1232 // A's diag * (B's diag + B's off-diag) 1233 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1234 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1235 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1236 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1237 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1238 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1239 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1240 PetscCallCXX(sort_crs_matrix(mm->C1)); 1241 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1242 #endif 1243 1244 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1245 1246 // A's off-diag * (F's diag + F's off-diag) 1247 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1248 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1249 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1250 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1251 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1252 PetscCallCXX(sort_crs_matrix(mm->C3)); 1253 PetscCallCXX(sort_crs_matrix(mm->C4)); 1254 #endif 1255 1256 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1257 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1258 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1259 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1260 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1261 1262 // C = (Cd, Co) = (C1+C3, C2+C4) 1263 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1264 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1265 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1266 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1267 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1268 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1269 PetscFunctionReturn(PETSC_SUCCESS); 1270 } 1271 1272 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1273 { 1274 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1275 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1276 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1277 1278 PetscFunctionBegin; 1279 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1280 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1281 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1282 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1283 1284 // Bcast B's rows to form F, and overlap the communication 1285 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1286 1287 // A's diag * (B's diag + B's off-diag) 1288 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1289 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1290 1291 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1292 1293 // A's off-diag * (F's diag + F's off-diag) 1294 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1295 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1296 1297 // C = (Cd, Co) = (C1+C3, C2+C4) 1298 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1299 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1300 PetscFunctionReturn(PETSC_SUCCESS); 1301 } 1302 1303 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1304 { 1305 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1306 Mat_Product *product; 1307 MatProductData_MPIAIJKokkos *pdata; 1308 
MatProductType ptype; 1309 Mat A, B; 1310 1311 PetscFunctionBegin; 1312 MatCheckProduct(C, 1); // make sure C is a product 1313 product = C->product; 1314 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1315 ptype = product->type; 1316 A = product->A; 1317 B = product->B; 1318 1319 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1320 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1321 // we still do numeric. 1322 if (pdata->reusesym) { // numeric reuses results from symbolic 1323 pdata->reusesym = PETSC_FALSE; 1324 PetscFunctionReturn(PETSC_SUCCESS); 1325 } 1326 1327 if (ptype == MATPRODUCT_AB) { 1328 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1329 } else if (ptype == MATPRODUCT_AtB) { 1330 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1331 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1332 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1333 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1334 } 1335 1336 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1337 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1338 PetscFunctionReturn(PETSC_SUCCESS); 1339 } 1340 1341 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1342 { 1343 Mat A, B; 1344 Mat_Product *product; 1345 MatProductType ptype; 1346 MatProductData_MPIAIJKokkos *pdata; 1347 MatMatStruct *mm = NULL; 1348 PetscInt m, n, M, N; 1349 Mat Cd, Co; 1350 MPI_Comm comm; 1351 1352 PetscFunctionBegin; 1353 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1354 MatCheckProduct(C, 1); 1355 product = C->product; 1356 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1357 ptype = product->type; 1358 A = product->A; 1359 B = product->B; 
1360 1361 switch (ptype) { 1362 case MATPRODUCT_AB: 1363 m = A->rmap->n; 1364 n = B->cmap->n; 1365 M = A->rmap->N; 1366 N = B->cmap->N; 1367 break; 1368 case MATPRODUCT_AtB: 1369 m = A->cmap->n; 1370 n = B->cmap->n; 1371 M = A->cmap->N; 1372 N = B->cmap->N; 1373 break; 1374 case MATPRODUCT_PtAP: 1375 m = B->cmap->n; 1376 n = B->cmap->n; 1377 M = B->cmap->N; 1378 N = B->cmap->N; 1379 break; /* BtAB */ 1380 default: 1381 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1382 } 1383 1384 PetscCall(MatSetSizes(C, m, n, M, N)); 1385 PetscCall(PetscLayoutSetUp(C->rmap)); 1386 PetscCall(PetscLayoutSetUp(C->cmap)); 1387 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1388 1389 pdata = new MatProductData_MPIAIJKokkos(); 1390 pdata->reusesym = product->api_user; 1391 1392 if (ptype == MATPRODUCT_AB) { 1393 auto mmAB = new MatMatStruct_AB(); 1394 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1395 mm = pdata->mmAB = mmAB; 1396 } else if (ptype == MATPRODUCT_AtB) { 1397 auto mmAtB = new MatMatStruct_AtB(); 1398 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1399 mm = pdata->mmAtB = mmAtB; 1400 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1401 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1402 1403 auto mmAB = new MatMatStruct_AB(); 1404 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1405 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1406 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1407 pdata->mmAB = mmAB; 1408 1409 m = A->rmap->n; // Z's layout 1410 n = B->cmap->n; 1411 M = A->rmap->N; 1412 N = B->cmap->N; 1413 PetscCall(MatCreate(comm, &Z)); 1414 PetscCall(MatSetSizes(Z, m, n, M, N)); 1415 PetscCall(PetscLayoutSetUp(Z->rmap)); 1416 PetscCall(PetscLayoutSetUp(Z->cmap)); 1417 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1418 
PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1419 1420 auto mmAtB = new MatMatStruct_AtB(); 1421 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1422 1423 pdata->Z = Z; // give ownership to pdata 1424 mm = pdata->mmAtB = mmAtB; 1425 } 1426 1427 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1428 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1429 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1430 1431 C->product->data = pdata; 1432 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1433 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1434 PetscFunctionReturn(PETSC_SUCCESS); 1435 } 1436 1437 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1438 { 1439 Mat_Product *product = mat->product; 1440 PetscBool match = PETSC_FALSE; 1441 PetscBool usecpu = PETSC_FALSE; 1442 1443 PetscFunctionBegin; 1444 MatCheckProduct(mat, 1); 1445 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1446 if (match) { /* we can always fallback to the CPU if requested */ 1447 switch (product->type) { 1448 case MATPRODUCT_AB: 1449 if (product->api_user) { 1450 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1451 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1452 PetscOptionsEnd(); 1453 } else { 1454 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1455 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1456 PetscOptionsEnd(); 1457 } 1458 break; 1459 case MATPRODUCT_AtB: 1460 if 
(product->api_user) { 1461 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1462 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1463 PetscOptionsEnd(); 1464 } else { 1465 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1466 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1467 PetscOptionsEnd(); 1468 } 1469 break; 1470 case MATPRODUCT_PtAP: 1471 if (product->api_user) { 1472 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1473 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1474 PetscOptionsEnd(); 1475 } else { 1476 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1477 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1478 PetscOptionsEnd(); 1479 } 1480 break; 1481 default: 1482 break; 1483 } 1484 match = (PetscBool)!usecpu; 1485 } 1486 if (match) { 1487 switch (product->type) { 1488 case MATPRODUCT_AB: 1489 case MATPRODUCT_AtB: 1490 case MATPRODUCT_PtAP: 1491 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1492 break; 1493 default: 1494 break; 1495 } 1496 } 1497 /* fallback to MPIAIJ ops */ 1498 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1499 PetscFunctionReturn(PETSC_SUCCESS); 1500 } 1501 1502 // Mirror of MatCOOStruct_MPIAIJ on device 1503 struct MatCOOStruct_MPIAIJKokkos { 1504 PetscCount n; 1505 PetscSF sf; 1506 PetscCount Annz, Bnnz; 1507 PetscCount Annz2, Bnnz2; 1508 PetscCountKokkosView Ajmap1, Aperm1; 1509 PetscCountKokkosView Bjmap1, Bperm1; 1510 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1511 
PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1512 PetscCountKokkosView Cperm1; 1513 MatScalarKokkosView sendbuf, recvbuf; 1514 1515 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) : 1516 n(coo_h->n), 1517 sf(coo_h->sf), 1518 Annz(coo_h->Annz), 1519 Bnnz(coo_h->Bnnz), 1520 Annz2(coo_h->Annz2), 1521 Bnnz2(coo_h->Bnnz2), 1522 Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))), 1523 Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))), 1524 Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))), 1525 Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))), 1526 Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))), 1527 Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))), 1528 Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))), 1529 Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))), 1530 Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))), 1531 Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))), 1532 Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))), 1533 sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))), 1534 recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), 
MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen))) 1535 { 1536 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1537 } 1538 1539 ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1540 }; 1541 1542 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1543 { 1544 PetscFunctionBegin; 1545 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1546 PetscFunctionReturn(PETSC_SUCCESS); 1547 } 1548 1549 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1550 { 1551 PetscContainer container_h, container_d; 1552 MatCOOStruct_MPIAIJ *coo_h; 1553 MatCOOStruct_MPIAIJKokkos *coo_d; 1554 1555 PetscFunctionBegin; 1556 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1557 mat->preallocated = PETSC_TRUE; 1558 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1559 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1560 PetscCall(MatZeroEntries(mat)); 1561 1562 // Copy the COO struct to device 1563 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1564 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1565 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1566 1567 // Put the COO struct in a container and then attach that to the matrix 1568 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1569 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1570 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1571 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1572 PetscCall(PetscContainerDestroy(&container_d)); 1573 PetscFunctionReturn(PETSC_SUCCESS); 1574 } 1575 1576 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1577 { 1578 Mat_MPIAIJ *mpiaij = 
static_cast<Mat_MPIAIJ *>(mat->data); 1579 Mat A = mpiaij->A, B = mpiaij->B; 1580 MatScalarKokkosView Aa, Ba; 1581 MatScalarKokkosView v1; 1582 PetscMemType memtype; 1583 PetscContainer container; 1584 MatCOOStruct_MPIAIJKokkos *coo; 1585 1586 PetscFunctionBegin; 1587 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1588 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1589 1590 const auto &n = coo->n; 1591 const auto &Annz = coo->Annz; 1592 const auto &Annz2 = coo->Annz2; 1593 const auto &Bnnz = coo->Bnnz; 1594 const auto &Bnnz2 = coo->Bnnz2; 1595 const auto &vsend = coo->sendbuf; 1596 const auto &v2 = coo->recvbuf; 1597 const auto &Ajmap1 = coo->Ajmap1; 1598 const auto &Ajmap2 = coo->Ajmap2; 1599 const auto &Aimap2 = coo->Aimap2; 1600 const auto &Bjmap1 = coo->Bjmap1; 1601 const auto &Bjmap2 = coo->Bjmap2; 1602 const auto &Bimap2 = coo->Bimap2; 1603 const auto &Aperm1 = coo->Aperm1; 1604 const auto &Aperm2 = coo->Aperm2; 1605 const auto &Bperm1 = coo->Bperm1; 1606 const auto &Bperm2 = coo->Bperm2; 1607 const auto &Cperm1 = coo->Cperm1; 1608 1609 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1610 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1611 v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n)); 1612 } else { 1613 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1614 } 1615 1616 if (imode == INSERT_VALUES) { 1617 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1618 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1619 } else { 1620 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1621 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1622 } 1623 1624 PetscCall(PetscLogGpuTimeBegin()); 1625 /* Pack entries to be sent to remote */ 1626 
Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 1627 1628 /* Send remote entries to their owner and overlap the communication with local computation */ 1629 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1630 /* Add local entries to A and B in one kernel */ 1631 Kokkos::parallel_for( 1632 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1633 PetscScalar sum = 0.0; 1634 if (i < Annz) { 1635 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1636 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1637 } else { 1638 i -= Annz; 1639 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1640 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1641 } 1642 }); 1643 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1644 1645 /* Add received remote entries to A and B in one kernel */ 1646 Kokkos::parallel_for( 1647 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1648 if (i < Annz2) { 1649 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1650 } else { 1651 i -= Annz2; 1652 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1653 } 1654 }); 1655 PetscCall(PetscLogGpuTimeEnd()); 1656 1657 if (imode == INSERT_VALUES) { 1658 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1659 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1660 } else { 1661 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1662 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1663 } 1664 PetscFunctionReturn(PETSC_SUCCESS); 1665 } 1666 1667 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1668 { 1669 PetscFunctionBegin; 1670 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1671 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1672 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1673 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1674 PetscCall(MatDestroy_MPIAIJ(A)); 1675 PetscFunctionReturn(PETSC_SUCCESS); 1676 } 1677 1678 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1679 { 1680 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1681 PetscBool congruent; 1682 1683 PetscFunctionBegin; 1684 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1685 if (congruent) { // square matrix and the diagonals are solely in the diag block 1686 PetscCall(MatShift(mpiaij->A, a)); 1687 } else { // too hard, use the general version 1688 PetscCall(MatShift_Basic(A, a)); 1689 } 1690 PetscFunctionReturn(PETSC_SUCCESS); 1691 } 1692 1693 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1694 { 1695 PetscFunctionBegin; 1696 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1697 B->ops->mult = MatMult_MPIAIJKokkos; 1698 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1699 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1700 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1701 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1702 B->ops->shift = MatShift_MPIAIJKokkos; 1703 1704 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1705 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1706 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1707 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1708 PetscFunctionReturn(PETSC_SUCCESS); 1709 } 1710 1711 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1712 { 1713 Mat B; 1714 Mat_MPIAIJ *a; 1715 1716 PetscFunctionBegin; 1717 if (reuse == MAT_INITIAL_MATRIX) { 1718 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1719 } else if (reuse == MAT_REUSE_MATRIX) { 1720 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1721 } 1722 B = *newmat; 1723 1724 B->boundtocpu = PETSC_FALSE; 1725 PetscCall(PetscFree(B->defaultvectype)); 1726 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1727 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1728 1729 a = static_cast<Mat_MPIAIJ *>(A->data); 1730 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1731 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1732 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1733 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1734 PetscFunctionReturn(PETSC_SUCCESS); 1735 } 1736 1737 /*MC 1738 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1739 1740 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1741 1742 Options Database Key: 1743 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1744 1745 Level: beginner 1746 1747 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1748 M*/ 1749 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1750 { 1751 PetscFunctionBegin; 1752 PetscCall(PetscKokkosInitializeCheck()); 1753 PetscCall(MatCreate_MPIAIJ(A)); 1754 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1755 PetscFunctionReturn(PETSC_SUCCESS); 1756 } 1757 1758 /*@C 1759 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1760 (the default parallel PETSc format). This matrix will ultimately pushed down 1761 to Kokkos for calculations. 1762 1763 Collective 1764 1765 Input Parameters: 1766 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1767 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1768 This value should be the same as the local size used in creating the 1769 y vector for the matrix-vector product y = Ax. 1770 . n - This value should be the same as the local size used in creating the 1771 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1772 calculated if N is given) For square matrices n is almost always `m`. 1773 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1774 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1775 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1776 (same value is used for all local rows) 1777 . d_nnz - array containing the number of nonzeros in the various rows of the 1778 DIAGONAL portion of the local submatrix (possibly different for each row) 1779 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1780 The size of this array is equal to the number of local rows, i.e `m`. 
1781 For matrices you plan to factor you must leave room for the diagonal entry and 1782 put in the entry even if it is zero. 1783 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1784 submatrix (same value is used for all local rows). 1785 - o_nnz - array containing the number of nonzeros in the various rows of the 1786 OFF-DIAGONAL portion of the local submatrix (possibly different for 1787 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1788 structure. The size of this array is equal to the number 1789 of local rows, i.e `m`. 1790 1791 Output Parameter: 1792 . A - the matrix 1793 1794 Level: intermediate 1795 1796 Notes: 1797 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1798 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1799 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1800 1801 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1802 storage. That is, the stored row and column indices can begin at 1803 either one (as in Fortran) or zero. 
1804 1805 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1806 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1807 @*/ 1808 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1809 { 1810 PetscMPIInt size; 1811 1812 PetscFunctionBegin; 1813 PetscCall(MatCreate(comm, A)); 1814 PetscCall(MatSetSizes(*A, m, n, M, N)); 1815 PetscCallMPI(MPI_Comm_size(comm, &size)); 1816 if (size > 1) { 1817 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1818 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1819 } else { 1820 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1821 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1822 } 1823 PetscFunctionReturn(PETSC_SUCCESS); 1824 } 1825