1 #include <petsc_kokkos.hpp> 2 #include <petscvec_kokkos.hpp> 3 #include <petscpkg_version.h> 4 #include <petsc/private/sfimpl.h> 5 #include <petsc/private/kokkosimpl.hpp> 6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 7 #include <../src/mat/impls/aij/mpi/mpiaij.h> 8 #include <KokkosSparse_spadd.hpp> 9 #include <KokkosSparse_spgemm.hpp> 10 11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 12 { 13 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 14 15 PetscFunctionBegin; 16 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 17 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 18 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 19 */ 20 if (mode == MAT_FINAL_ASSEMBLY) { 21 PetscScalarKokkosView v; 22 23 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 24 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 25 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device 26 PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device 27 PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v)); 28 } 29 PetscFunctionReturn(PETSC_SUCCESS); 30 } 31 32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 33 { 34 Mat_MPIAIJ *mpiaij; 35 36 PetscFunctionBegin; 37 // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things 38 PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz)); 39 mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 40 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A)); 41 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B)); 42 PetscFunctionReturn(PETSC_SUCCESS); 43 } 44 45 static 
PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 46 { 47 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 48 PetscInt nt; 49 50 PetscFunctionBegin; 51 PetscCall(VecGetLocalSize(xx, &nt)); 52 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 53 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 54 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 55 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 56 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 57 PetscFunctionReturn(PETSC_SUCCESS); 58 } 59 60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 61 { 62 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 63 PetscInt nt; 64 65 PetscFunctionBegin; 66 PetscCall(VecGetLocalSize(xx, &nt)); 67 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 68 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 69 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 70 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 71 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 72 PetscFunctionReturn(PETSC_SUCCESS); 73 } 74 75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 76 { 77 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 78 PetscInt nt; 79 80 PetscFunctionBegin; 81 PetscCall(VecGetLocalSize(xx, &nt)); 82 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 83 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 84 
PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 85 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 86 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 87 PetscFunctionReturn(PETSC_SUCCESS); 88 } 89 90 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 91 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 92 C still uses local column ids. Their corresponding global column ids are returned in glob. 93 */ 94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 95 { 96 Mat Ad, Ao; 97 const PetscInt *cmap; 98 99 PetscFunctionBegin; 100 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 101 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 102 if (glob) { 103 PetscInt cst, i, dn, on, *gidx; 104 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 105 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 106 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 107 PetscCall(PetscMalloc1(dn + on, &gidx)); 108 for (i = 0; i < dn; i++) gidx[i] = cst + i; 109 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 110 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 111 } 112 PetscFunctionReturn(PETSC_SUCCESS); 113 } 114 115 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 116 struct MatMatStruct { 117 PetscInt n, *garray; // C's garray and its size. 
118 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 119 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 120 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 121 PetscIntKokkosView E_NzLeft; 122 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 123 MatScalarKokkosView rootBuf, leafBuf; 124 KokkosCsrMatrix Fd, Fo; // F in split form 125 126 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 127 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 128 KernelHandle kh3; // compute C3 129 KernelHandle kh4; // compute C4 130 131 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 132 PetscInt E_VectorLength; 133 PetscInt E_RowsPerTeam; 134 PetscInt F_TeamSize; 135 PetscInt F_VectorLength; 136 PetscInt F_RowsPerTeam; 137 138 ~MatMatStruct() 139 { 140 PetscFunctionBegin; 141 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 142 PetscFunctionReturnVoid(); 143 } 144 }; 145 146 struct MatMatStruct_AB : public MatMatStruct { 147 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 148 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 149 PetscIntKokkosView rowoffset; 150 }; 151 152 struct MatMatStruct_AtB : public MatMatStruct { 153 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 154 MatColIdxKokkosView Fdjperm; 155 MatColIdxKokkosView Fojmap; 156 MatColIdxKokkosView Fojperm; 157 }; 158 159 struct MatProductData_MPIAIJKokkos { 160 MatMatStruct_AB *mmAB = nullptr; 161 MatMatStruct_AtB *mmAtB = nullptr; 162 PetscBool reusesym = PETSC_FALSE; 163 Mat Z = nullptr; // store Z=AB in computing BtAB 164 165 ~MatProductData_MPIAIJKokkos() 166 { 167 delete mmAB; 168 delete mmAtB; 169 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 170 } 171 }; 172 173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 174 { 175 PetscFunctionBegin; 176 
PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 177 PetscFunctionReturn(PETSC_SUCCESS); 178 } 179 180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 181 It is similar to MatCreateMPIAIJWithSplitArrays. 182 183 Input Parameters: 184 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 185 . A - the diag matrix using local col ids 186 - B - the offdiag matrix using global col ids 187 188 Output Parameter: 189 . mat - the updated MATMPIAIJKOKKOS matrix 190 */ 191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 192 { 193 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 194 PetscInt m, n, M, N, Am, An, Bm, Bn; 195 196 PetscFunctionBegin; 197 PetscCall(MatGetSize(mat, &M, &N)); 198 PetscCall(MatGetLocalSize(mat, &m, &n)); 199 PetscCall(MatGetLocalSize(A, &Am, &An)); 200 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 201 202 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 203 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 204 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 205 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 206 mpiaij->A = A; 207 mpiaij->B = B; 208 mpiaij->garray = garray; 209 210 mat->preallocated = PETSC_TRUE; 211 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 212 213 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 214 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 215 /* MatAssemblyEnd is critical here. 
It sets mat->offloadmask according to A and B's, and 216 also gets mpiaij->B compacted, with its col ids and size reduced 217 */ 218 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 219 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 220 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 221 PetscFunctionReturn(PETSC_SUCCESS); 222 } 223 224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 226 template <class ExecutionSpace> 227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 228 { 229 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 230 231 PetscFunctionBegin; 232 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 233 234 if (nnz_per_row < 1) nnz_per_row = 1; 235 236 int max_vector_length = teamPolicy.vector_length_max(); 237 238 if (vector_length < 1) { 239 vector_length = 1; 240 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 241 } 242 243 // Determine rows per thread 244 if (rows_per_thread < 1) { 245 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 246 else { 247 if (nnz_per_row < 20 && nnz > 5000000) { 248 rows_per_thread = 256; 249 } else rows_per_thread = 64; 250 } 251 } 252 253 if (team_size < 1) { 254 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 255 team_size = 256 / vector_length; 256 } else { 257 team_size = 1; 258 } 259 } 260 261 rows_per_team = rows_per_thread * team_size; 262 263 if (rows_per_team < 0) { 264 PetscInt nnz_per_team = 4096; 265 PetscInt conc = ExecutionSpace().concurrency(); 266 while ((conc * nnz_per_team * 
4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 267 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 268 } 269 PetscFunctionReturn(PETSC_SUCCESS); 270 } 271 272 /* 273 Reduce two sets of global indices into local ones 274 275 Input Parameters: 276 + n1 - size of garray1[], the first set 277 . garray1[n1] - a sorted global index array (without duplicates) 278 . m - size of indices[], the second set 279 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 280 281 Output Parameters: 282 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 283 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 284 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 285 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 286 287 Example, say 288 n1 = 5 289 garray1[5] = {1, 4, 7, 8, 10} 290 m = 4 291 indices[4] = {2, 4, 8, 9} 292 293 Combining them together, we have 7 global indices in garray2[] 294 n2 = 7 295 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 296 297 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 298 map[5] = {0, 2, 3, 4, 6} 299 300 On output, indices[] is updated with local indices 301 indices[4] = {1, 2, 4, 5} 302 */ 303 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 304 { 305 PetscHMapI g2l = nullptr; 306 PetscHashIter iter; 307 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 308 PetscInt n2, *garray2; 309 310 PetscFunctionBegin; 311 tot = 0; 312 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 313 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 314 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 315 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 316 } 317 318 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 319 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 320 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 321 } 322 323 // Pull out (unique) globals in the hash table and put them in garray2[] 324 n2 = tot; 325 PetscCall(PetscMalloc1(n2, &garray2)); 326 tot = 0; 327 PetscHashIterBegin(g2l, iter); 328 while (!PetscHashIterAtEnd(g2l, iter)) { 329 PetscHashIterGetKey(g2l, iter, key); 330 PetscHashIterNext(g2l, iter); 331 garray2[tot++] = key; 332 } 333 334 // Sort garray2[] and then map them to local indices starting from 0 335 PetscCall(PetscSortInt(n2, garray2)); 336 PetscCall(PetscHMapIClear(g2l)); 337 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 338 339 // Rewrite indices[] with local indices 340 for (PetscInt i = 0; i < m; i++) { 341 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 342 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 343 indices[i] = val; 344 } 345 // Record the map that maps garray1[i] to garray2[map[i]] 346 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 347 PetscCall(PetscHMapIDestroy(&g2l)); 348 *n2_ = n2; 349 *garray2_ = garray2; 350 PetscFunctionReturn(PETSC_SUCCESS); 351 } 352 353 /* 354 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 355 
  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.

  Think of each row of E as a leaf; the given ownerSF then specifies the roots of these leaves. A root may connect to multiple leaves.
  In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.

  Input Parameters:
+ comm        - MPI communicator of E
. A           - diag block of E, using local column indices
. B           - off-diag block of E, using local column indices
. cstart      - (global) start column of Ed
. cend        - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend)
. garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
. ownerSF     - the SF that specifies ownership (roots) of rows in E
. reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm          - to stash intermediate data structures for reuse

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
- mm      - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
  The routine is provided in split-phase form, MatMPIAIJKokkosReduceBegin()/MatMPIAIJKokkosReduceEnd().
378 379 */ 380 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 381 { 382 PetscFunctionBegin; 383 if (reuse == MAT_INITIAL_MATRIX) { 384 PetscInt Em = A.numRows(), Fm; 385 PetscInt n1 = B.numCols(); 386 387 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 388 389 // Do the analysis on host 390 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 391 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 392 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 393 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 394 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 395 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 396 397 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 398 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 399 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 400 for (PetscInt i = 0; i < Em; i++) { 401 const PetscInt *first, *last, *it; 402 PetscInt count, step; 403 // std::lower_bound(first,last,cstart), but need to use global column indices 404 first = Bj + Bi[i]; 405 last = Bj + Bi[i + 1]; 406 count = last - first; 407 while (count > 0) { 408 it = first; 409 step = count / 2; 410 it += step; 411 if (garray1[*it] < cstart) { // map local to global 412 first = ++it; 413 count -= step + 1; 414 } else count = step; 415 } 416 E_NzLeft[i] = first - (Bj + Bi[i]); 417 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 418 } 419 420 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 421 const PetscMPIInt *iranks, *ranks; 422 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 423 PetscMPIInt niranks, nranks; 424 MPI_Request *reqs; 425 PetscMPIInt tag; 426 PetscSF reduceSF; 427 PetscInt *sdisp, *rdisp; 428 429 PetscCall(PetscCommGetNewTag(comm, &tag)); 430 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 431 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 432 433 // Find out length of each row I will receive. Even for the same row index, when they are from 434 // different senders, they might have different lengths (and sparsity patterns) 435 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 436 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 437 438 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 439 440 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 441 recvRowLen[0] = 0; // since we will make it in CSR format later 442 recvRowLen++; // advance the pointer now 443 for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 444 for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); 445 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 446 447 // Build the real PetscSF for reducing E rows (buffer to buffer) 448 rdisp[0] = 0; 449 for (PetscInt i = 0; i < niranks; i++) { 450 rdisp[i + 1] = rdisp[i]; 451 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 452 } 453 recvRowLen--; // put it back into csr format 454 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] 
+= recvRowLen[i]; 455 456 for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); 457 for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 458 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 459 460 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 461 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 462 PetscSFNode *iremote; 463 464 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 465 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 466 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 467 468 for (PetscInt i = 0; i < nranks; i++) { 469 PetscInt count = 0; 470 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 471 for (PetscInt j = 0; j < count; j++) { 472 iremote[nleaves + j].rank = ranks[i]; 473 iremote[nleaves + j].index = sdisp[i] + j; 474 } 475 nleaves += count; 476 } 477 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 478 479 PetscCall(PetscSFCreate(comm, &reduceSF)); 480 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 481 482 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 483 PetscInt *sendCol, *recvCol; 484 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 485 for (PetscInt k = 0; k < roffset[nranks]; k++) { 486 PetscInt i = rmine[k]; // row to be copied 487 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 488 PetscInt nzLeft = E_NzLeft[i]; 489 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 490 for (PetscInt j = 0; j < alen + blen; j++) { 491 if (j < nzLeft) { 492 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 493 } else if (j < nzLeft + alen) { 494 buf[j] = Aj[Ai[i] + j - 
nzLeft] + cstart; // diag A, also in global 495 } else { 496 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 497 } 498 } 499 } 500 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 501 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 502 503 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 504 PetscInt *recvRowPerm, *recvColSorted; 505 PetscInt *recvNzPerm, *recvNzPermSorted; 506 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 507 508 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 509 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 510 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 511 512 // i[] array, nz are always easiest to compute 513 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 514 MatRowMapType *Fdi, *Foi; 515 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 516 PetscInt iter; 517 518 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 519 Kokkos::deep_copy(Foi_h, 0); 520 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 521 Foi = Foi_h.data() + 1; 522 iter = 0; 523 while (iter < recvRowCnt) { // iter over received rows 524 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 525 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 526 527 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 528 529 // Copy column indices (and 
their permutation) of these rows into recvColSorted & recvNzPermSorted 530 PetscInt nz = 0; // nz (with dups) in the current row 531 PetscInt *jbuf = recvColSorted + FnzDups; 532 PetscInt *pbuf = recvNzPermSorted + FnzDups; 533 PetscInt *jbuf2 = jbuf; // temp pointers 534 PetscInt *pbuf2 = pbuf; 535 for (PetscInt d = 0; d < dupRows; d++) { 536 PetscInt i = recvRowPerm[iter + d]; 537 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 538 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 539 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 540 jbuf2 += len; 541 pbuf2 += len; 542 nz += len; 543 } 544 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 545 546 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 547 PetscInt cur = 0; 548 while (cur < nz) { 549 PetscInt curColIdx = jbuf[cur]; 550 PetscInt dups = 1; 551 552 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 553 if (curColIdx >= cstart && curColIdx < cend) { 554 Fdi[curRowIdx]++; 555 FdnzDups += dups; 556 } else { 557 Foi[curRowIdx]++; 558 FonzDups += dups; 559 } 560 cur += dups; 561 } 562 563 FnzDups += nz; 564 iter += dupRows; // Move to next unique row 565 } 566 567 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 568 Foi = Foi_h.data(); 569 for (PetscInt i = 0; i < Fm; i++) { 570 Fdi[i + 1] += Fdi[i]; 571 Foi[i + 1] += Foi[i]; 572 } 573 Fdnz = Fdi[Fm]; 574 Fonz = Foi[Fm]; 575 PetscCall(PetscFree2(sendCol, recvCol)); 576 577 // Allocate j, jmap, jperm for Fd and Fo 578 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 579 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 580 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 581 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 582 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 583 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 584 585 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 586 Fdjmap[0] = 0; 587 Fojmap[0] = 0; 588 FnzDups = 0; 589 Fdnz = 0; 590 Fonz = 0; 591 iter = 0; // iter over received rows 592 while (iter < recvRowCnt) { 593 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 594 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 595 PetscInt nz = 0; // nz (with dups) in the current row 596 597 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 598 for (PetscInt d = 0; d < dupRows; d++) { 599 PetscInt i = recvRowPerm[iter + d]; 600 nz += recvRowLen[i + 1] - recvRowLen[i]; 601 } 602 603 PetscInt *jbuf = recvColSorted + FnzDups; 604 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 605 PetscInt cur = 0; 606 while (cur < nz) { 607 PetscInt curColIdx = jbuf[cur]; 608 PetscInt dups = 1; 609 610 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 611 if (curColIdx >= cstart && curColIdx < cend) { 612 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 613 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 614 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 615 FdnzDups += dups; 616 Fdnz++; 617 } else { 618 Foj[Fonz] = curColIdx; // in global 619 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 620 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 621 FonzDups += dups; 622 Fonz++; 623 } 624 cur += dups; 625 FnzDups += dups; 626 } 627 iter += dupRows; // Move to next unique row 628 } 629 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 630 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 631 632 // Combine global column indices in garray1[] and Foj[] 633 PetscInt n2, *garray2; 634 635 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 636 mm->sf = reduceSF; 637 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 638 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 639 mm->garray = garray2; // give ownership, so no free 640 mm->n = n2; 641 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 642 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 643 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 644 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 645 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 646 647 // Output Fd and Fo in KokkosCsrMatrix format 648 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 649 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 650 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 651 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 652 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 653 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 654 655 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 656 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 657 658 // Compute kernel launch parameters in merging E 659 PetscInt teamSize, vectorLength, rowsPerTeam; 660 661 teamSize = vectorLength = rowsPerTeam = -1; 662 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 663 mm->E_TeamSize = teamSize; 664 mm->E_VectorLength = 
vectorLength; 665 mm->E_RowsPerTeam = rowsPerTeam; 666 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 667 668 // Handy aliases 669 auto &Aa = A.values; 670 auto &Ba = B.values; 671 const auto &Ai = A.graph.row_map; 672 const auto &Bi = B.graph.row_map; 673 const auto &E_NzLeft = mm->E_NzLeft; 674 auto &leafBuf = mm->leafBuf; 675 auto &rootBuf = mm->rootBuf; 676 PetscSF reduceSF = mm->sf; 677 PetscInt Em = A.numRows(); 678 PetscInt teamSize = mm->E_TeamSize; 679 PetscInt vectorLength = mm->E_VectorLength; 680 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 681 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 682 683 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 684 PetscCallCXX(Kokkos::parallel_for( 685 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 686 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 687 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 688 if (i < Em) { 689 PetscInt disp = Ai(i) + Bi(i); 690 PetscInt alen = Ai(i + 1) - Ai(i); 691 PetscInt blen = Bi(i + 1) - Bi(i); 692 PetscInt nzleft = E_NzLeft(i); 693 694 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 695 MatScalar &val = leafBuf(disp + j); 696 if (j < nzleft) { // B left 697 val = Ba(Bi(i) + j); 698 } else if (j < nzleft + alen) { // diag A 699 val = Aa(Ai(i) + j - nzleft); 700 } else { // B right 701 val = Ba(Bi(i) + j - alen); 702 } 703 }); 704 } 705 }); 706 })); 707 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 708 PetscFunctionReturn(PETSC_SUCCESS); 709 } 710 711 // To finish MatMPIAIJKokkosReduce. 
712 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 713 { 714 auto &leafBuf = mm->leafBuf; 715 auto &rootBuf = mm->rootBuf; 716 auto &Fda = mm->Fd.values; 717 const auto &Fdjmap = mm->Fdjmap; 718 const auto &Fdjperm = mm->Fdjperm; 719 auto Fdnz = mm->Fd.nnz(); 720 auto &Foa = mm->Fo.values; 721 const auto &Fojmap = mm->Fojmap; 722 const auto &Fojperm = mm->Fojperm; 723 auto Fonz = mm->Fo.nnz(); 724 PetscSF reduceSF = mm->sf; 725 726 PetscFunctionBegin; 727 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 728 729 // Reduce data in rootBuf to Fd and Fo 730 PetscCallCXX(Kokkos::parallel_for( 731 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 732 PetscScalar sum = 0.0; 733 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 734 Fda(i) = sum; 735 })); 736 737 PetscCallCXX(Kokkos::parallel_for( 738 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 739 PetscScalar sum = 0.0; 740 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 741 Foa(i) = sum; 742 })); 743 PetscFunctionReturn(PETSC_SUCCESS); 744 } 745 746 /* 747 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 748 749 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 750 device and involves various index mapping. 751 752 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 
753 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 754 to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 755 F has the same column layout as E. 756 757 Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 758 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 759 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 760 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 761 column indices in Fo and update Fo with local indices. 762 763 Input Parameters: 764 + E - the MPIAIJKOKKOS matrix 765 . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 766 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 767 - mm - to stash matproduct intermediate data structures 768 769 Output Parameters: 770 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 771 - mm - contains various info, such as garray2[], Fd, Fo, etc. 772 773 Notes: 774 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 775 The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 
776 */ 777 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 778 { 779 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 780 Mat A = empi->A, B = empi->B; // diag and off-diag 781 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 782 PetscInt Em = E->rmap->n; // #local rows 783 MPI_Comm comm; 784 785 PetscFunctionBegin; 786 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 787 if (reuse == MAT_INITIAL_MATRIX) { 788 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 789 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 790 const PetscInt *garray1 = empi->garray; // its size is n1 791 PetscInt cstart, cend; 792 PetscSF bcastSF; 793 794 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 795 796 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 797 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 798 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 799 for (PetscInt i = 0; i < Em; i++) { 800 const PetscInt *first, *last, *it; 801 PetscInt count, step; 802 // std::lower_bound(first,last,cstart), but need to use global column indices 803 first = Bj + Bi[i]; 804 last = Bj + Bi[i + 1]; 805 count = last - first; 806 while (count > 0) { 807 it = first; 808 step = count / 2; 809 it += step; 810 if (empi->garray[*it] < cstart) { // map local to global 811 first = ++it; 812 count -= step + 1; 813 } else count = step; 814 } 815 E_NzLeft[i] = first - (Bj + Bi[i]); 816 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 817 } 818 819 // Compute row pointer Fi of F 820 PetscInt *Fi, Fm, Fnz; 821 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 822 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 823 Fi[0] = 0; 824 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 825 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 826 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 827 Fnz = Fi[Fm]; 828 829 // Build the real PetscSF for bcasting E rows (buffer to buffer) 830 const PetscMPIInt *iranks, *ranks; 831 const PetscInt *ioffset, *irootloc, *roffset; 832 PetscMPIInt niranks, nranks; 833 PetscInt *sdisp, *rdisp; 834 MPI_Request *reqs; 835 PetscMPIInt tag; 836 837 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 838 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 839 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 840 841 sdisp[0] = 0; // send displacement 842 for (PetscInt i = 0; i < niranks; i++) { 843 sdisp[i + 1] = sdisp[i]; 844 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 845 PetscInt r = irootloc[j]; // row to be sent 846 sdisp[i + 1] += E_RowLen[r]; 847 } 848 } 849 850 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 851 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 852 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 853 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 854 855 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 856 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 857 PetscSFNode *iremote; // give ownership to bcastSF 858 PetscCall(PetscMalloc1(nleaves, &iremote)); 859 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 860 PetscInt k = 0; 861 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows 
[roffset[i], roffset[i+1]) of F from ranks[i] 862 iremote[j].rank = ranks[i]; 863 iremote[j].index = rdisp[i] + k; // their root location 864 k++; 865 } 866 } 867 PetscCall(PetscSFCreate(comm, &bcastSF)); 868 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 869 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 870 871 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 872 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 873 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 874 rowoffset[0] = 0; 875 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 876 877 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 878 PetscInt *jbuf, *Fj; 879 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 880 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 881 PetscInt i = irootloc[k]; // row to be copied 882 PetscInt *buf = &jbuf[rowoffset[k]]; 883 PetscInt nzLeft = E_NzLeft[i]; 884 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 885 for (PetscInt j = 0; j < alen + blen; j++) { 886 if (j < nzLeft) { 887 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 888 } else if (j < nzLeft + alen) { 889 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 890 } else { 891 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 892 } 893 } 894 } 895 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 896 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 897 898 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 899 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 900 
MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 901 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 902 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 903 904 Fdi[0] = Foi[0] = 0; 905 for (PetscInt i = 0; i < Fm; i++) { 906 PetscInt *first, *last, *lb1, *lb2; 907 // cut the row into: Left, [cstart, cend), Right 908 first = Fj + Fi[i]; 909 last = Fj + Fi[i + 1]; 910 lb1 = std::lower_bound(first, last, cstart); 911 F_NzLeft[i] = lb1 - first; 912 lb2 = std::lower_bound(first, last, cend); 913 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 914 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 915 } 916 for (PetscInt i = 0; i < Fm; i++) { 917 Fdi[i + 1] += Fdi[i]; 918 Foi[i + 1] += Foi[i]; 919 } 920 921 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 922 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 923 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 924 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 925 926 for (PetscInt i = 0; i < Fm; i++) { 927 PetscInt nzLeft = F_NzLeft[i]; 928 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 929 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 930 gid = Fj[Fi[i] + j]; 931 if (j < nzLeft) { // left, in global 932 Foj[Foi[i] + j] = gid; 933 } else if (j < nzLeft + len) { // diag, in local 934 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 935 } else { // right, in global 936 Foj[Foi[i] + j - len] = gid; 937 } 938 } 939 } 940 PetscCall(PetscFree2(jbuf, Fj)); 941 PetscCall(PetscFree(Fi)); 942 943 // Reduce global indices in Foj[] and garray1[] into local ones 944 PetscInt n2, *garray2; 945 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 946 947 // Record the plans built above, for reuse 948 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] 
is owned by ownerSF. We create a copy for safety 949 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 950 Kokkos::deep_copy(irootloc_h, tmp); 951 mm->sf = bcastSF; 952 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 953 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 954 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 955 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 956 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 957 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 958 mm->garray = garray2; 959 mm->n = n2; 960 961 // Output Fd and Fo in KokkosCsrMatrix format 962 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 963 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 964 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 965 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 966 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 967 968 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 969 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 970 971 // Compute kernel launch parameters in merging E or splitting F 972 PetscInt teamSize, vectorLength, rowsPerTeam; 973 974 teamSize = vectorLength = rowsPerTeam = -1; 975 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 976 mm->E_TeamSize = teamSize; 977 mm->E_VectorLength = vectorLength; 978 mm->E_RowsPerTeam = rowsPerTeam; 979 980 teamSize = vectorLength = rowsPerTeam = -1; 981 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, 
Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 982 mm->F_TeamSize = teamSize; 983 mm->F_VectorLength = vectorLength; 984 mm->F_RowsPerTeam = rowsPerTeam; 985 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 986 987 // Sync E's value to device 988 akok->a_dual.sync_device(); 989 bkok->a_dual.sync_device(); 990 991 // Handy aliases 992 const auto &Aa = akok->a_dual.view_device(); 993 const auto &Ba = bkok->a_dual.view_device(); 994 const auto &Ai = akok->i_dual.view_device(); 995 const auto &Bi = bkok->i_dual.view_device(); 996 997 // Fetch the plans 998 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 999 PetscSF &bcastSF = mm->sf; 1000 MatScalarKokkosView &rootBuf = mm->rootBuf; 1001 MatScalarKokkosView &leafBuf = mm->leafBuf; 1002 PetscIntKokkosView &irootloc = mm->irootloc; 1003 PetscIntKokkosView &rowoffset = mm->rowoffset; 1004 1005 PetscInt teamSize = mm->E_TeamSize; 1006 PetscInt vectorLength = mm->E_VectorLength; 1007 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1008 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1009 1010 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1011 PetscCallCXX(Kokkos::parallel_for( 1012 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1013 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1014 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1015 if (r < irootloc.extent(0)) { 1016 PetscInt i = irootloc(r); // row i of E 1017 PetscInt disp = rowoffset(r); 1018 PetscInt alen = Ai(i + 1) - Ai(i); 1019 PetscInt blen = Bi(i + 1) - Bi(i); 1020 PetscInt nzleft = E_NzLeft(i); 1021 1022 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1023 if (j < nzleft) { // B left 1024 rootBuf(disp + j) = Ba(Bi(i) + j); 1025 } else if (j < nzleft + alen) { // diag A 1026 rootBuf(disp + j) 
= Aa(Ai(i) + j - nzleft); 1027 } else { // B right 1028 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1029 } 1030 }); 1031 } 1032 }); 1033 })); 1034 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1035 PetscFunctionReturn(PETSC_SUCCESS); 1036 } 1037 1038 // To finish MatMPIAIJKokkosBcast. 1039 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1040 { 1041 PetscFunctionBegin; 1042 const auto &Fd = mm->Fd; 1043 const auto &Fo = mm->Fo; 1044 const auto &Fdi = Fd.graph.row_map; 1045 const auto &Foi = Fo.graph.row_map; 1046 auto &Fda = Fd.values; 1047 auto &Foa = Fo.values; 1048 auto Fm = Fd.numRows(); 1049 1050 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1051 PetscSF &bcastSF = mm->sf; 1052 MatScalarKokkosView &rootBuf = mm->rootBuf; 1053 MatScalarKokkosView &leafBuf = mm->leafBuf; 1054 PetscInt teamSize = mm->F_TeamSize; 1055 PetscInt vectorLength = mm->F_VectorLength; 1056 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1057 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1058 1059 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1060 1061 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1062 PetscCallCXX(Kokkos::parallel_for( 1063 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1064 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1065 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1066 if (i < Fm) { 1067 PetscInt nzLeft = F_NzLeft(i); 1068 PetscInt alen = Fdi(i + 1) - Fdi(i); 1069 PetscInt blen = Foi(i + 1) - Foi(i); 1070 PetscInt Fii = Fdi(i) + Foi(i); 1071 1072 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1073 PetscScalar val = leafBuf(Fii + j); 1074 if (j < 
nzLeft) { // left 1075 Foa(Foi(i) + j) = val; 1076 } else if (j < nzLeft + alen) { // diag 1077 Fda(Fdi(i) + j - nzLeft) = val; 1078 } else { // right 1079 Foa(Foi(i) + j - alen) = val; 1080 } 1081 }); 1082 } 1083 }); 1084 })); 1085 PetscFunctionReturn(PETSC_SUCCESS); 1086 } 1087 1088 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1089 { 1090 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1091 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1092 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1093 PetscInt cstart, cend; 1094 MPI_Comm comm; 1095 1096 PetscFunctionBegin; 1097 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1098 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1099 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1100 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1101 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1102 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1103 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1104 1105 // TODO: add command line options to select spgemm algorithms 1106 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1107 1108 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1109 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1110 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1111 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1112 #endif 1113 #endif 1114 1115 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1116 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1117 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1118 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1119 1120 // Aot * (B's diag + B's off-diag) 1121 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1122 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1123 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1124 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1125 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1126 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1127 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1128 1129 PetscCallCXX(sort_crs_matrix(mm->C3)); 1130 PetscCallCXX(sort_crs_matrix(mm->C4)); 1131 #endif 1132 1133 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1134 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1135 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1136 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1137 1138 // Adt * (B's diag + B's off-diag) 1139 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1140 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1141 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1142 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1143 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1144 PetscCallCXX(sort_crs_matrix(mm->C1)); 1145 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1146 #endif 1147 1148 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1149 1150 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1151 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1152 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1153 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1154 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1155 1156 // C = (C1+Fd, C2+Fo) 1157 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1158 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1159 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1160 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1161 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1162 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1163 PetscFunctionReturn(PETSC_SUCCESS); 1164 } 1165 1166 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1167 { 1168 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1169 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1170 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1171 MPI_Comm comm; 1172 1173 PetscFunctionBegin; 1174 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1175 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1176 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1177 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1178 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1179 1180 // Aot * (B's diag + B's off-diag) 1181 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1182 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1183 1184 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1185 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1186 1187 // Adt * (B's diag + B's off-diag) 1188 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1189 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1190 1191 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1192 1193 // C = (C1+Fd, C2+Fo) 1194 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1195 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1196 PetscFunctionReturn(PETSC_SUCCESS); 1197 } 1198 1199 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1200 1201 Input Parameters: 1202 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1203 . A - an MPIAIJKOKKOS matrix 1204 . B - an MPIAIJKOKKOS matrix 1205 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1206 */ 1207 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1208 { 1209 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1210 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1211 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1212 1213 PetscFunctionBegin; 1214 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1215 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1216 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1217 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1218 1219 // TODO: add command line options to select spgemm algorithms 1220 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1221 1222 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1223 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1224 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1225 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1226 #endif 1227 #endif 1228 1229 mm->kh1.create_spgemm_handle(spgemm_alg); 1230 mm->kh2.create_spgemm_handle(spgemm_alg); 1231 mm->kh3.create_spgemm_handle(spgemm_alg); 1232 mm->kh4.create_spgemm_handle(spgemm_alg); 1233 1234 // Bcast B's rows to form F, and overlap the communication 1235 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1236 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1237 1238 // A's diag * (B's diag + B's off-diag) 1239 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1240 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1241 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1242 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1243 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1244 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1245 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1246 PetscCallCXX(sort_crs_matrix(mm->C1)); 1247 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1248 #endif 1249 1250 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1251 1252 // A's off-diag * (F's diag + F's off-diag) 1253 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1254 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1255 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1256 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1257 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1258 PetscCallCXX(sort_crs_matrix(mm->C3)); 1259 PetscCallCXX(sort_crs_matrix(mm->C4)); 1260 #endif 1261 1262 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1263 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1264 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1265 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1266 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1267 1268 // C = (Cd, Co) = (C1+C3, C2+C4) 1269 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1270 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1271 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1272 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1273 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1274 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1275 PetscFunctionReturn(PETSC_SUCCESS); 1276 } 1277 1278 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1279 { 1280 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1281 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1282 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1283 1284 PetscFunctionBegin; 1285 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1286 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1287 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1288 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1289 1290 // Bcast B's rows to form F, and overlap the communication 1291 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1292 1293 // A's diag * (B's diag + B's off-diag) 1294 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1295 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1296 1297 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1298 1299 // A's off-diag * (F's diag + F's off-diag) 1300 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1301 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1302 1303 // C = (Cd, Co) = (C1+C3, C2+C4) 1304 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1305 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1306 PetscFunctionReturn(PETSC_SUCCESS); 1307 } 1308 1309 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1310 { 1311 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1312 Mat_Product *product; 1313 MatProductData_MPIAIJKokkos *pdata; 1314 
MatProductType ptype; 1315 Mat A, B; 1316 1317 PetscFunctionBegin; 1318 MatCheckProduct(C, 1); // make sure C is a product 1319 product = C->product; 1320 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1321 ptype = product->type; 1322 A = product->A; 1323 B = product->B; 1324 1325 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1326 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1327 // we still do numeric. 1328 if (pdata->reusesym) { // numeric reuses results from symbolic 1329 pdata->reusesym = PETSC_FALSE; 1330 PetscFunctionReturn(PETSC_SUCCESS); 1331 } 1332 1333 if (ptype == MATPRODUCT_AB) { 1334 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1335 } else if (ptype == MATPRODUCT_AtB) { 1336 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1337 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1338 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1339 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1340 } 1341 1342 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1343 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1344 PetscFunctionReturn(PETSC_SUCCESS); 1345 } 1346 1347 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1348 { 1349 Mat A, B; 1350 Mat_Product *product; 1351 MatProductType ptype; 1352 MatProductData_MPIAIJKokkos *pdata; 1353 MatMatStruct *mm = NULL; 1354 PetscInt m, n, M, N; 1355 Mat Cd, Co; 1356 MPI_Comm comm; 1357 1358 PetscFunctionBegin; 1359 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1360 MatCheckProduct(C, 1); 1361 product = C->product; 1362 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1363 ptype = product->type; 1364 A = product->A; 1365 B = product->B; 
1366 1367 switch (ptype) { 1368 case MATPRODUCT_AB: 1369 m = A->rmap->n; 1370 n = B->cmap->n; 1371 M = A->rmap->N; 1372 N = B->cmap->N; 1373 break; 1374 case MATPRODUCT_AtB: 1375 m = A->cmap->n; 1376 n = B->cmap->n; 1377 M = A->cmap->N; 1378 N = B->cmap->N; 1379 break; 1380 case MATPRODUCT_PtAP: 1381 m = B->cmap->n; 1382 n = B->cmap->n; 1383 M = B->cmap->N; 1384 N = B->cmap->N; 1385 break; /* BtAB */ 1386 default: 1387 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1388 } 1389 1390 PetscCall(MatSetSizes(C, m, n, M, N)); 1391 PetscCall(PetscLayoutSetUp(C->rmap)); 1392 PetscCall(PetscLayoutSetUp(C->cmap)); 1393 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1394 1395 pdata = new MatProductData_MPIAIJKokkos(); 1396 pdata->reusesym = product->api_user; 1397 1398 if (ptype == MATPRODUCT_AB) { 1399 auto mmAB = new MatMatStruct_AB(); 1400 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1401 mm = pdata->mmAB = mmAB; 1402 } else if (ptype == MATPRODUCT_AtB) { 1403 auto mmAtB = new MatMatStruct_AtB(); 1404 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1405 mm = pdata->mmAtB = mmAtB; 1406 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1407 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1408 1409 auto mmAB = new MatMatStruct_AB(); 1410 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1411 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1412 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1413 pdata->mmAB = mmAB; 1414 1415 m = A->rmap->n; // Z's layout 1416 n = B->cmap->n; 1417 M = A->rmap->N; 1418 N = B->cmap->N; 1419 PetscCall(MatCreate(comm, &Z)); 1420 PetscCall(MatSetSizes(Z, m, n, M, N)); 1421 PetscCall(PetscLayoutSetUp(Z->rmap)); 1422 PetscCall(PetscLayoutSetUp(Z->cmap)); 1423 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1424 
PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1425 1426 auto mmAtB = new MatMatStruct_AtB(); 1427 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1428 1429 pdata->Z = Z; // give ownership to pdata 1430 mm = pdata->mmAtB = mmAtB; 1431 } 1432 1433 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1434 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1435 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1436 1437 C->product->data = pdata; 1438 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1439 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1440 PetscFunctionReturn(PETSC_SUCCESS); 1441 } 1442 1443 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1444 { 1445 Mat_Product *product = mat->product; 1446 PetscBool match = PETSC_FALSE; 1447 PetscBool usecpu = PETSC_FALSE; 1448 1449 PetscFunctionBegin; 1450 MatCheckProduct(mat, 1); 1451 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1452 if (match) { /* we can always fallback to the CPU if requested */ 1453 switch (product->type) { 1454 case MATPRODUCT_AB: 1455 if (product->api_user) { 1456 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1457 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1458 PetscOptionsEnd(); 1459 } else { 1460 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1461 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1462 PetscOptionsEnd(); 1463 } 1464 break; 1465 case MATPRODUCT_AtB: 1466 if 
(product->api_user) { 1467 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1468 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1469 PetscOptionsEnd(); 1470 } else { 1471 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1472 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1473 PetscOptionsEnd(); 1474 } 1475 break; 1476 case MATPRODUCT_PtAP: 1477 if (product->api_user) { 1478 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1479 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1480 PetscOptionsEnd(); 1481 } else { 1482 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1483 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1484 PetscOptionsEnd(); 1485 } 1486 break; 1487 default: 1488 break; 1489 } 1490 match = (PetscBool)!usecpu; 1491 } 1492 if (match) { 1493 switch (product->type) { 1494 case MATPRODUCT_AB: 1495 case MATPRODUCT_AtB: 1496 case MATPRODUCT_PtAP: 1497 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1498 break; 1499 default: 1500 break; 1501 } 1502 } 1503 /* fallback to MPIAIJ ops */ 1504 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1505 PetscFunctionReturn(PETSC_SUCCESS); 1506 } 1507 1508 // Mirror of MatCOOStruct_MPIAIJ on device 1509 struct MatCOOStruct_MPIAIJKokkos { 1510 PetscCount n; 1511 PetscSF sf; 1512 PetscCount Annz, Bnnz; 1513 PetscCount Annz2, Bnnz2; 1514 PetscCountKokkosView Ajmap1, Aperm1; 1515 PetscCountKokkosView Bjmap1, Bperm1; 1516 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1517 
PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1518 PetscCountKokkosView Cperm1; 1519 MatScalarKokkosView sendbuf, recvbuf; 1520 1521 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) 1522 { 1523 auto &exec = PetscGetKokkosExecutionSpace(); 1524 1525 n = coo_h->n; 1526 sf = coo_h->sf; 1527 Annz = coo_h->Annz; 1528 Bnnz = coo_h->Bnnz; 1529 Annz2 = coo_h->Annz2; 1530 Bnnz2 = coo_h->Bnnz2; 1531 Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1)); 1532 Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1)); 1533 Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1)); 1534 Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1)); 1535 Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2)); 1536 Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1)); 1537 Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2)); 1538 Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2)); 1539 Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1)); 1540 Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2)); 1541 Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen)); 1542 sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen)); 1543 recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)); 1544 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1545 } 1546 1547 
~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1548 }; 1549 1550 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1551 { 1552 PetscFunctionBegin; 1553 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1554 PetscFunctionReturn(PETSC_SUCCESS); 1555 } 1556 1557 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1558 { 1559 PetscContainer container_h, container_d; 1560 MatCOOStruct_MPIAIJ *coo_h; 1561 MatCOOStruct_MPIAIJKokkos *coo_d; 1562 1563 PetscFunctionBegin; 1564 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1565 mat->preallocated = PETSC_TRUE; 1566 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1567 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1568 PetscCall(MatZeroEntries(mat)); 1569 1570 // Copy the COO struct to device 1571 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1572 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1573 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1574 1575 // Put the COO struct in a container and then attach that to the matrix 1576 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1577 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1578 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1579 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1580 PetscCall(PetscContainerDestroy(&container_d)); 1581 PetscFunctionReturn(PETSC_SUCCESS); 1582 } 1583 1584 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1585 { 1586 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1587 Mat A = mpiaij->A, B = mpiaij->B; 1588 MatScalarKokkosView Aa, Ba; 1589 MatScalarKokkosView v1; 
1590 PetscMemType memtype; 1591 PetscContainer container; 1592 MatCOOStruct_MPIAIJKokkos *coo; 1593 Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 1594 1595 PetscFunctionBegin; 1596 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1597 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1598 1599 const auto &n = coo->n; 1600 const auto &Annz = coo->Annz; 1601 const auto &Annz2 = coo->Annz2; 1602 const auto &Bnnz = coo->Bnnz; 1603 const auto &Bnnz2 = coo->Bnnz2; 1604 const auto &vsend = coo->sendbuf; 1605 const auto &v2 = coo->recvbuf; 1606 const auto &Ajmap1 = coo->Ajmap1; 1607 const auto &Ajmap2 = coo->Ajmap2; 1608 const auto &Aimap2 = coo->Aimap2; 1609 const auto &Bjmap1 = coo->Bjmap1; 1610 const auto &Bjmap2 = coo->Bjmap2; 1611 const auto &Bimap2 = coo->Bimap2; 1612 const auto &Aperm1 = coo->Aperm1; 1613 const auto &Aperm2 = coo->Aperm2; 1614 const auto &Bperm1 = coo->Bperm1; 1615 const auto &Bperm2 = coo->Bperm2; 1616 const auto &Cperm1 = coo->Cperm1; 1617 1618 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1619 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1620 v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n)); 1621 } else { 1622 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1623 } 1624 1625 if (imode == INSERT_VALUES) { 1626 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1627 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1628 } else { 1629 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1630 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1631 } 1632 1633 PetscCall(PetscLogGpuTimeBegin()); 1634 /* Pack entries to be sent to remote */ 1635 Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) 
= v1(Cperm1(i)); }); 1636 1637 /* Send remote entries to their owner and overlap the communication with local computation */ 1638 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1639 /* Add local entries to A and B in one kernel */ 1640 Kokkos::parallel_for( 1641 Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1642 PetscScalar sum = 0.0; 1643 if (i < Annz) { 1644 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1645 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1646 } else { 1647 i -= Annz; 1648 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1649 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1650 } 1651 }); 1652 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1653 1654 /* Add received remote entries to A and B in one kernel */ 1655 Kokkos::parallel_for( 1656 Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1657 if (i < Annz2) { 1658 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1659 } else { 1660 i -= Annz2; 1661 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1662 } 1663 }); 1664 PetscCall(PetscLogGpuTimeEnd()); 1665 1666 if (imode == INSERT_VALUES) { 1667 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1668 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1669 } else { 1670 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1671 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1672 } 1673 PetscFunctionReturn(PETSC_SUCCESS); 1674 } 1675 1676 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1677 { 1678 PetscFunctionBegin; 1679 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1680 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1681 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1682 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1683 PetscCall(MatDestroy_MPIAIJ(A)); 1684 PetscFunctionReturn(PETSC_SUCCESS); 1685 } 1686 1687 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1688 { 1689 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1690 PetscBool congruent; 1691 1692 PetscFunctionBegin; 1693 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1694 if (congruent) { // square matrix and the diagonals are solely in the diag block 1695 PetscCall(MatShift(mpiaij->A, a)); 1696 } else { // too hard, use the general version 1697 PetscCall(MatShift_Basic(A, a)); 1698 } 1699 PetscFunctionReturn(PETSC_SUCCESS); 1700 } 1701 1702 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1703 { 1704 PetscFunctionBegin; 1705 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1706 B->ops->mult = MatMult_MPIAIJKokkos; 1707 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1708 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1709 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1710 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1711 B->ops->shift = MatShift_MPIAIJKokkos; 1712 1713 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1714 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1715 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1716 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1717 PetscFunctionReturn(PETSC_SUCCESS); 1718 } 1719 1720 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1721 { 1722 Mat B; 1723 Mat_MPIAIJ *a; 1724 1725 PetscFunctionBegin; 1726 if (reuse == MAT_INITIAL_MATRIX) { 1727 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1728 } else if (reuse == MAT_REUSE_MATRIX) { 1729 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1730 } 1731 B = *newmat; 1732 1733 B->boundtocpu = PETSC_FALSE; 1734 PetscCall(PetscFree(B->defaultvectype)); 1735 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1736 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1737 1738 a = static_cast<Mat_MPIAIJ *>(A->data); 1739 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1740 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1741 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1742 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1743 PetscFunctionReturn(PETSC_SUCCESS); 1744 } 1745 1746 /*MC 1747 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1748 1749 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1750 1751 Options Database Key: 1752 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1753 1754 Level: beginner 1755 1756 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1757 M*/ 1758 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1759 { 1760 PetscFunctionBegin; 1761 PetscCall(PetscKokkosInitializeCheck()); 1762 PetscCall(MatCreate_MPIAIJ(A)); 1763 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1764 PetscFunctionReturn(PETSC_SUCCESS); 1765 } 1766 1767 /*@C 1768 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1769 (the default parallel PETSc format). This matrix will ultimately pushed down 1770 to Kokkos for calculations. 1771 1772 Collective 1773 1774 Input Parameters: 1775 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1776 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1777 This value should be the same as the local size used in creating the 1778 y vector for the matrix-vector product y = Ax. 1779 . n - This value should be the same as the local size used in creating the 1780 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1781 calculated if N is given) For square matrices n is almost always `m`. 1782 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1783 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1784 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1785 (same value is used for all local rows) 1786 . d_nnz - array containing the number of nonzeros in the various rows of the 1787 DIAGONAL portion of the local submatrix (possibly different for each row) 1788 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1789 The size of this array is equal to the number of local rows, i.e `m`. 
1790 For matrices you plan to factor you must leave room for the diagonal entry and 1791 put in the entry even if it is zero. 1792 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1793 submatrix (same value is used for all local rows). 1794 - o_nnz - array containing the number of nonzeros in the various rows of the 1795 OFF-DIAGONAL portion of the local submatrix (possibly different for 1796 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1797 structure. The size of this array is equal to the number 1798 of local rows, i.e `m`. 1799 1800 Output Parameter: 1801 . A - the matrix 1802 1803 Level: intermediate 1804 1805 Notes: 1806 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1807 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1808 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1809 1810 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1811 storage. That is, the stored row and column indices can begin at 1812 either one (as in Fortran) or zero. 
1813 1814 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1815 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1816 @*/ 1817 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1818 { 1819 PetscMPIInt size; 1820 1821 PetscFunctionBegin; 1822 PetscCall(MatCreate(comm, A)); 1823 PetscCall(MatSetSizes(*A, m, n, M, N)); 1824 PetscCallMPI(MPI_Comm_size(comm, &size)); 1825 if (size > 1) { 1826 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1827 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1828 } else { 1829 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1830 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1831 } 1832 PetscFunctionReturn(PETSC_SUCCESS); 1833 } 1834