1 #include <petsc_kokkos.hpp> 2 #include <petscvec_kokkos.hpp> 3 #include <petscpkg_version.h> 4 #include <petsc/private/sfimpl.h> 5 #include <petsc/private/kokkosimpl.hpp> 6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 7 #include <../src/mat/impls/aij/mpi/mpiaij.h> 8 #include <KokkosSparse_spadd.hpp> 9 #include <KokkosSparse_spgemm.hpp> 10 11 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 12 { 13 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 14 15 PetscFunctionBegin; 16 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 17 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 18 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 19 */ 20 if (mode == MAT_FINAL_ASSEMBLY) { 21 PetscScalarKokkosView v; 22 23 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 24 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 25 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device 26 PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device 27 PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v)); 28 } 29 PetscFunctionReturn(PETSC_SUCCESS); 30 } 31 32 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 33 { 34 Mat_MPIAIJ *mpiaij; 35 36 PetscFunctionBegin; 37 // reuse MPIAIJ's preallocation, which sets A/B's blocksize along other things 38 PetscCall(MatMPIAIJSetPreallocation_MPIAIJ(mat, d_nz, d_nnz, o_nz, o_nnz)); 39 mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 40 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->A)); 41 PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(mpiaij->B, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &mpiaij->B)); 42 PetscFunctionReturn(PETSC_SUCCESS); 43 } 44 45 static 
PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 46 { 47 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 48 PetscInt nt; 49 50 PetscFunctionBegin; 51 PetscCall(VecGetLocalSize(xx, &nt)); 52 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 53 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 54 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 55 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 56 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 57 PetscFunctionReturn(PETSC_SUCCESS); 58 } 59 60 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 61 { 62 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 63 PetscInt nt; 64 65 PetscFunctionBegin; 66 PetscCall(VecGetLocalSize(xx, &nt)); 67 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 68 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 69 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 70 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 71 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 72 PetscFunctionReturn(PETSC_SUCCESS); 73 } 74 75 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 76 { 77 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 78 PetscInt nt; 79 80 PetscFunctionBegin; 81 PetscCall(VecGetLocalSize(xx, &nt)); 82 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 83 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 84 
PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 85 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 86 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 87 PetscFunctionReturn(PETSC_SUCCESS); 88 } 89 90 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 91 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 92 C still uses local column ids. Their corresponding global column ids are returned in glob. 93 */ 94 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 95 { 96 Mat Ad, Ao; 97 const PetscInt *cmap; 98 99 PetscFunctionBegin; 100 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 101 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 102 if (glob) { 103 PetscInt cst, i, dn, on, *gidx; 104 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 105 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 106 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 107 PetscCall(PetscMalloc1(dn + on, &gidx)); 108 for (i = 0; i < dn; i++) gidx[i] = cst + i; 109 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 110 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 111 } 112 PetscFunctionReturn(PETSC_SUCCESS); 113 } 114 115 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 116 struct MatMatStruct { 117 PetscInt n, *garray; // C's garray and its size. 
118 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 119 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 120 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 121 PetscIntKokkosView E_NzLeft; 122 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 123 MatScalarKokkosView rootBuf, leafBuf; 124 KokkosCsrMatrix Fd, Fo; // F in split form 125 126 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 127 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 128 KernelHandle kh3; // compute C3 129 KernelHandle kh4; // compute C4 130 131 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 132 PetscInt E_VectorLength; 133 PetscInt E_RowsPerTeam; 134 PetscInt F_TeamSize; 135 PetscInt F_VectorLength; 136 PetscInt F_RowsPerTeam; 137 138 ~MatMatStruct() 139 { 140 PetscFunctionBegin; 141 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 142 PetscFunctionReturnVoid(); 143 } 144 }; 145 146 struct MatMatStruct_AB : public MatMatStruct { 147 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 148 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 149 PetscIntKokkosView rowoffset; 150 }; 151 152 struct MatMatStruct_AtB : public MatMatStruct { 153 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 154 MatColIdxKokkosView Fdjperm; 155 MatColIdxKokkosView Fojmap; 156 MatColIdxKokkosView Fojperm; 157 }; 158 159 struct MatProductData_MPIAIJKokkos { 160 MatMatStruct_AB *mmAB = nullptr; 161 MatMatStruct_AtB *mmAtB = nullptr; 162 PetscBool reusesym = PETSC_FALSE; 163 Mat Z = nullptr; // store Z=AB in computing BtAB 164 165 ~MatProductData_MPIAIJKokkos() 166 { 167 delete mmAB; 168 delete mmAtB; 169 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 170 } 171 }; 172 173 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 174 { 175 PetscFunctionBegin; 176 
PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 177 PetscFunctionReturn(PETSC_SUCCESS); 178 } 179 180 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 181 It is similar to MatCreateMPIAIJWithSplitArrays. 182 183 Input Parameters: 184 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 185 . A - the diag matrix using local col ids 186 - B - the offdiag matrix using global col ids 187 188 Output Parameter: 189 . mat - the updated MATMPIAIJKOKKOS matrix 190 */ 191 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 192 { 193 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 194 PetscInt m, n, M, N, Am, An, Bm, Bn; 195 196 PetscFunctionBegin; 197 PetscCall(MatGetSize(mat, &M, &N)); 198 PetscCall(MatGetLocalSize(mat, &m, &n)); 199 PetscCall(MatGetLocalSize(A, &Am, &An)); 200 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 201 202 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 203 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 204 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 205 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 206 mpiaij->A = A; 207 mpiaij->B = B; 208 mpiaij->garray = garray; 209 210 mat->preallocated = PETSC_TRUE; 211 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 212 213 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 214 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 215 /* MatAssemblyEnd is critical here. 
It sets mat->offloadmask according to A and B's, and 216 also gets mpiaij->B compacted, with its col ids and size reduced 217 */ 218 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 219 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 220 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 221 PetscFunctionReturn(PETSC_SUCCESS); 222 } 223 224 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 225 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 226 template <class ExecutionSpace> 227 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 228 { 229 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 230 231 PetscFunctionBegin; 232 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 233 234 if (nnz_per_row < 1) nnz_per_row = 1; 235 236 int max_vector_length = teamPolicy.vector_length_max(); 237 238 if (vector_length < 1) { 239 vector_length = 1; 240 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 241 } 242 243 // Determine rows per thread 244 if (rows_per_thread < 1) { 245 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 246 else { 247 if (nnz_per_row < 20 && nnz > 5000000) { 248 rows_per_thread = 256; 249 } else rows_per_thread = 64; 250 } 251 } 252 253 if (team_size < 1) { 254 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 255 team_size = 256 / vector_length; 256 } else { 257 team_size = 1; 258 } 259 } 260 261 rows_per_team = rows_per_thread * team_size; 262 263 if (rows_per_team < 0) { 264 PetscInt nnz_per_team = 4096; 265 PetscInt conc = ExecutionSpace().concurrency(); 266 while ((conc * nnz_per_team * 
4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 267 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 268 } 269 PetscFunctionReturn(PETSC_SUCCESS); 270 } 271 272 /* 273 Reduce two sets of global indices into local ones 274 275 Input Parameters: 276 + n1 - size of garray1[], the first set 277 . garray1[n1] - a sorted global index array (without duplicates) 278 . m - size of indices[], the second set 279 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 280 281 Output Parameters: 282 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 283 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 284 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 285 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 286 287 Example, say 288 n1 = 5 289 garray1[5] = {1, 4, 7, 8, 10} 290 m = 4 291 indices[4] = {2, 4, 8, 9} 292 293 Combining them together, we have 7 global indices in garray2[] 294 n2 = 7 295 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 296 297 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 298 map[5] = {0, 2, 3, 4, 6} 299 300 On output, indices[] is updated with local indices 301 indices[4] = {1, 2, 4, 5} 302 */ 303 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 304 { 305 PetscHMapI g2l = nullptr; 306 PetscHashIter iter; 307 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 308 PetscInt n2, *garray2; 309 310 PetscFunctionBegin; 311 tot = 0; 312 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 313 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 314 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 315 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 316 } 317 318 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 319 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 320 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 321 } 322 323 // Pull out (unique) globals in the hash table and put them in garray2[] 324 n2 = tot; 325 PetscCall(PetscMalloc1(n2, &garray2)); 326 tot = 0; 327 PetscHashIterBegin(g2l, iter); 328 while (!PetscHashIterAtEnd(g2l, iter)) { 329 PetscHashIterGetKey(g2l, iter, key); 330 PetscHashIterNext(g2l, iter); 331 garray2[tot++] = key; 332 } 333 334 // Sort garray2[] and then map them to local indices starting from 0 335 PetscCall(PetscSortInt(n2, garray2)); 336 PetscCall(PetscHMapIClear(g2l)); 337 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 338 339 // Rewrite indices[] with local indices 340 for (PetscInt i = 0; i < m; i++) { 341 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 342 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 343 indices[i] = val; 344 } 345 // Record the map that maps garray1[i] to garray2[map[i]] 346 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 347 PetscCall(PetscHMapIDestroy(&g2l)); 348 *n2_ = n2; 349 *garray2_ = garray2; 350 PetscFunctionReturn(PETSC_SUCCESS); 351 } 352 353 /* 354 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 355 
  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature, since we do not really need a fully populated MPIAIJKOKKOS E.
  It is provided in split-phase form, MatMPIAIJKokkosReduceBegin/End().

  Think of each row of E as a leaf; the given ownerSF then specifies roots for the leaves. Roots may connect to multiple leaves.
  In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.

  Input Parameters:
+  comm        - MPI communicator of E
.  A           - diag block of E, using local column indices
.  B           - off-diag block of E, using local column indices
.  cstart      - (global) start column of Ed
.  cend        - (global) end column + 1 of Ed. In other words, E's column ownership is in the range [cstart, cend)
.  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
.  ownerSF     - the SF that specifies ownership (roots) of rows in E
.  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
-  mm          - to stash intermediate data structures for reuse

  Output Parameters:
+  map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
-  mm      - contains various info, such as garray2[], F (Fd, Fo), etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF and map are not significant.
378 379 */ 380 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 381 { 382 PetscFunctionBegin; 383 if (reuse == MAT_INITIAL_MATRIX) { 384 PetscInt Em = A.numRows(), Fm; 385 PetscInt n1 = B.numCols(); 386 387 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 388 389 // Do the analysis on host 390 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 391 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 392 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 393 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 394 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 395 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 396 397 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 398 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 399 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 400 for (PetscInt i = 0; i < Em; i++) { 401 const PetscInt *first, *last, *it; 402 PetscInt count, step; 403 // std::lower_bound(first,last,cstart), but need to use global column indices 404 first = Bj + Bi[i]; 405 last = Bj + Bi[i + 1]; 406 count = last - first; 407 while (count > 0) { 408 it = first; 409 step = count / 2; 410 it += step; 411 if (garray1[*it] < cstart) { // map local to global 412 first = ++it; 413 count -= step + 1; 414 } else count = step; 415 } 416 E_NzLeft[i] = first - (Bj + Bi[i]); 417 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 418 } 419 420 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 421 const PetscMPIInt *iranks, *ranks; 422 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 423 PetscInt niranks, nranks; 424 MPI_Request *reqs; 425 PetscMPIInt tag; 426 PetscSF reduceSF; 427 PetscInt *sdisp, *rdisp; 428 429 PetscCall(PetscCommGetNewTag(comm, &tag)); 430 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 431 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 432 433 // Find out length of each row I will receive. Even for the same row index, when they are from 434 // different senders, they might have different lengths (and sparsity patterns) 435 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 436 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 437 438 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 439 440 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 441 recvRowLen[0] = 0; // since we will make it in CSR format later 442 recvRowLen++; // advance the pointer now 443 for (PetscInt i = 0; i < niranks; i++) MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 444 for (PetscInt i = 0; i < nranks; i++) MPIU_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); 445 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 446 447 // Build the real PetscSF for reducing E rows (buffer to buffer) 448 rdisp[0] = 0; 449 for (PetscInt i = 0; i < niranks; i++) { 450 rdisp[i + 1] = rdisp[i]; 451 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 452 } 453 recvRowLen--; // put it back into csr format 454 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += 
recvRowLen[i]; 455 456 for (PetscInt i = 0; i < nranks; i++) MPIU_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); 457 for (PetscInt i = 0; i < niranks; i++) MPIU_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); 458 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 459 460 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 461 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 462 PetscSFNode *iremote; 463 464 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 465 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 466 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 467 468 for (PetscInt i = 0; i < nranks; i++) { 469 PetscInt count = 0; 470 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 471 for (PetscInt j = 0; j < count; j++) { 472 iremote[nleaves + j].rank = ranks[i]; 473 iremote[nleaves + j].index = sdisp[i] + j; 474 } 475 nleaves += count; 476 } 477 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 478 479 PetscCall(PetscSFCreate(comm, &reduceSF)); 480 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 481 482 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 483 PetscInt *sendCol, *recvCol; 484 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 485 for (PetscInt k = 0; k < roffset[nranks]; k++) { 486 PetscInt i = rmine[k]; // row to be copied 487 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 488 PetscInt nzLeft = E_NzLeft[i]; 489 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 490 for (PetscInt j = 0; j < alen + blen; j++) { 491 if (j < nzLeft) { 492 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 493 } else if (j < nzLeft + alen) { 494 buf[j] = Aj[Ai[i] + j - 
nzLeft] + cstart; // diag A, also in global 495 } else { 496 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 497 } 498 } 499 } 500 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 501 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 502 503 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 504 PetscInt *recvRowPerm, *recvColSorted; 505 PetscInt *recvNzPerm, *recvNzPermSorted; 506 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 507 508 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 509 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 510 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 511 512 // i[] array, nz are always easiest to compute 513 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 514 MatRowMapType *Fdi, *Foi; 515 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 516 PetscInt iter; 517 518 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 519 Kokkos::deep_copy(Foi_h, 0); 520 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 521 Foi = Foi_h.data() + 1; 522 iter = 0; 523 while (iter < recvRowCnt) { // iter over received rows 524 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 525 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 526 527 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 528 529 // Copy column indices (and 
their permutation) of these rows into recvColSorted & recvNzPermSorted 530 PetscInt nz = 0; // nz (with dups) in the current row 531 PetscInt *jbuf = recvColSorted + FnzDups; 532 PetscInt *pbuf = recvNzPermSorted + FnzDups; 533 PetscInt *jbuf2 = jbuf; // temp pointers 534 PetscInt *pbuf2 = pbuf; 535 for (PetscInt d = 0; d < dupRows; d++) { 536 PetscInt i = recvRowPerm[iter + d]; 537 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 538 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 539 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 540 jbuf2 += len; 541 pbuf2 += len; 542 nz += len; 543 } 544 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 545 546 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 547 PetscInt cur = 0; 548 while (cur < nz) { 549 PetscInt curColIdx = jbuf[cur]; 550 PetscInt dups = 1; 551 552 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 553 if (curColIdx >= cstart && curColIdx < cend) { 554 Fdi[curRowIdx]++; 555 FdnzDups += dups; 556 } else { 557 Foi[curRowIdx]++; 558 FonzDups += dups; 559 } 560 cur += dups; 561 } 562 563 FnzDups += nz; 564 iter += dupRows; // Move to next unique row 565 } 566 567 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 568 Foi = Foi_h.data(); 569 for (PetscInt i = 0; i < Fm; i++) { 570 Fdi[i + 1] += Fdi[i]; 571 Foi[i + 1] += Foi[i]; 572 } 573 Fdnz = Fdi[Fm]; 574 Fonz = Foi[Fm]; 575 PetscCall(PetscFree2(sendCol, recvCol)); 576 577 // Allocate j, jmap, jperm for Fd and Fo 578 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 579 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 580 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 581 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 582 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 583 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 584 585 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 586 Fdjmap[0] = 0; 587 Fojmap[0] = 0; 588 FnzDups = 0; 589 Fdnz = 0; 590 Fonz = 0; 591 iter = 0; // iter over received rows 592 while (iter < recvRowCnt) { 593 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 594 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 595 PetscInt nz = 0; // nz (with dups) in the current row 596 597 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 598 for (PetscInt d = 0; d < dupRows; d++) { 599 PetscInt i = recvRowPerm[iter + d]; 600 nz += recvRowLen[i + 1] - recvRowLen[i]; 601 } 602 603 PetscInt *jbuf = recvColSorted + FnzDups; 604 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 605 PetscInt cur = 0; 606 while (cur < nz) { 607 PetscInt curColIdx = jbuf[cur]; 608 PetscInt dups = 1; 609 610 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 611 if (curColIdx >= cstart && curColIdx < cend) { 612 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 613 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 614 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 615 FdnzDups += dups; 616 Fdnz++; 617 } else { 618 Foj[Fonz] = curColIdx; // in global 619 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 620 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 621 FonzDups += dups; 622 Fonz++; 623 } 624 cur += dups; 625 FnzDups += dups; 626 } 627 iter += dupRows; // Move to next unique row 628 } 629 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 630 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 631 632 // Combine global column indices in garray1[] and Foj[] 633 PetscInt n2, *garray2; 634 635 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 636 mm->sf = reduceSF; 637 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 638 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 639 mm->garray = garray2; // give ownership, so no free 640 mm->n = n2; 641 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 642 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 643 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 644 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 645 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 646 647 // Output Fd and Fo in KokkosCsrMatrix format 648 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 649 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 650 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 651 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 652 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 653 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 654 655 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 656 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 657 658 // Compute kernel launch parameters in merging E 659 PetscInt teamSize, vectorLength, rowsPerTeam; 660 661 teamSize = vectorLength = rowsPerTeam = -1; 662 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 663 mm->E_TeamSize = teamSize; 664 mm->E_VectorLength = 
vectorLength; 665 mm->E_RowsPerTeam = rowsPerTeam; 666 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 667 668 // Handy aliases 669 auto &Aa = A.values; 670 auto &Ba = B.values; 671 const auto &Ai = A.graph.row_map; 672 const auto &Bi = B.graph.row_map; 673 const auto &E_NzLeft = mm->E_NzLeft; 674 auto &leafBuf = mm->leafBuf; 675 auto &rootBuf = mm->rootBuf; 676 PetscSF reduceSF = mm->sf; 677 PetscInt Em = A.numRows(); 678 PetscInt teamSize = mm->E_TeamSize; 679 PetscInt vectorLength = mm->E_VectorLength; 680 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 681 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 682 683 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 684 PetscCallCXX(Kokkos::parallel_for( 685 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 686 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 687 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 688 if (i < Em) { 689 PetscInt disp = Ai(i) + Bi(i); 690 PetscInt alen = Ai(i + 1) - Ai(i); 691 PetscInt blen = Bi(i + 1) - Bi(i); 692 PetscInt nzleft = E_NzLeft(i); 693 694 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 695 MatScalar &val = leafBuf(disp + j); 696 if (j < nzleft) { // B left 697 val = Ba(Bi(i) + j); 698 } else if (j < nzleft + alen) { // diag A 699 val = Aa(Ai(i) + j - nzleft); 700 } else { // B right 701 val = Ba(Bi(i) + j - alen); 702 } 703 }); 704 } 705 }); 706 })); 707 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 708 PetscFunctionReturn(PETSC_SUCCESS); 709 } 710 711 // To finish MatMPIAIJKokkosReduce. 
712 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 713 { 714 auto &leafBuf = mm->leafBuf; 715 auto &rootBuf = mm->rootBuf; 716 auto &Fda = mm->Fd.values; 717 const auto &Fdjmap = mm->Fdjmap; 718 const auto &Fdjperm = mm->Fdjperm; 719 auto Fdnz = mm->Fd.nnz(); 720 auto &Foa = mm->Fo.values; 721 const auto &Fojmap = mm->Fojmap; 722 const auto &Fojperm = mm->Fojperm; 723 auto Fonz = mm->Fo.nnz(); 724 PetscSF reduceSF = mm->sf; 725 726 PetscFunctionBegin; 727 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 728 729 // Reduce data in rootBuf to Fd and Fo 730 PetscCallCXX(Kokkos::parallel_for( 731 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 732 PetscScalar sum = 0.0; 733 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 734 Fda(i) = sum; 735 })); 736 737 PetscCallCXX(Kokkos::parallel_for( 738 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 739 PetscScalar sum = 0.0; 740 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 741 Foa(i) = sum; 742 })); 743 PetscFunctionReturn(PETSC_SUCCESS); 744 } 745 746 /* 747 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 748 749 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 750 device and involves various index mapping. 751 752 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 
753 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 754 to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 755 F has the same column layout as E. 756 757 Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 758 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 759 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 760 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 761 column indices in Fo and update Fo with local indices. 762 763 Input Parameters: 764 + E - the MPIAIJKOKKOS matrix 765 . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 766 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 767 - mm - to stash matproduct intermediate data structures 768 769 Output Parameters: 770 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 771 - mm - contains various info, such as garray2[], Fd, Fo, etc. 772 773 Notes: 774 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 775 The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 
776 */ 777 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 778 { 779 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 780 Mat A = empi->A, B = empi->B; // diag and off-diag 781 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 782 PetscInt Em = E->rmap->n; // #local rows 783 MPI_Comm comm; 784 785 PetscFunctionBegin; 786 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 787 if (reuse == MAT_INITIAL_MATRIX) { 788 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 789 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 790 const PetscInt *garray1 = empi->garray; // its size is n1 791 PetscInt cstart, cend; 792 PetscSF bcastSF; 793 794 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 795 796 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 797 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 798 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 799 for (PetscInt i = 0; i < Em; i++) { 800 const PetscInt *first, *last, *it; 801 PetscInt count, step; 802 // std::lower_bound(first,last,cstart), but need to use global column indices 803 first = Bj + Bi[i]; 804 last = Bj + Bi[i + 1]; 805 count = last - first; 806 while (count > 0) { 807 it = first; 808 step = count / 2; 809 it += step; 810 if (empi->garray[*it] < cstart) { // map local to global 811 first = ++it; 812 count -= step + 1; 813 } else count = step; 814 } 815 E_NzLeft[i] = first - (Bj + Bi[i]); 816 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 817 } 818 819 // Compute row pointer Fi of F 820 PetscInt *Fi, Fm, Fnz; 821 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 822 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 823 Fi[0] = 0; 824 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 825 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 826 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 827 Fnz = Fi[Fm]; 828 829 // Build the real PetscSF for bcasting E rows (buffer to buffer) 830 const PetscMPIInt *iranks, *ranks; 831 const PetscInt *ioffset, *irootloc, *roffset; 832 PetscInt niranks, nranks, *sdisp, *rdisp; 833 MPI_Request *reqs; 834 PetscMPIInt tag; 835 836 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 837 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 838 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 839 840 sdisp[0] = 0; // send displacement 841 for (PetscInt i = 0; i < niranks; i++) { 842 sdisp[i + 1] = sdisp[i]; 843 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 844 PetscInt r = irootloc[j]; // row to be sent 845 sdisp[i + 1] += E_RowLen[r]; 846 } 847 } 848 849 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 850 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPIU_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 851 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPIU_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 852 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 853 854 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 855 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 856 PetscSFNode *iremote; // give ownership to bcastSF 857 PetscCall(PetscMalloc1(nleaves, &iremote)); 858 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 859 PetscInt k = 0; 860 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 861 iremote[j].rank = ranks[i]; 862 iremote[j].index = rdisp[i] + k; // their root location 863 k++; 864 } 865 } 866 PetscCall(PetscSFCreate(comm, &bcastSF)); 867 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 868 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 869 870 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 871 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 872 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 873 rowoffset[0] = 0; 874 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 875 876 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 877 PetscInt *jbuf, *Fj; 878 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 879 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 880 PetscInt i = irootloc[k]; // row to be copied 881 PetscInt *buf = &jbuf[rowoffset[k]]; 882 PetscInt nzLeft = E_NzLeft[i]; 883 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 884 for (PetscInt j = 0; j < alen + blen; j++) { 885 if (j < nzLeft) { 886 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 887 } else if (j < nzLeft + alen) { 888 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 889 } else { 890 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 891 } 892 } 893 } 894 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 895 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 896 897 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 898 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 899 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 900 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 901 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 902 903 Fdi[0] = Foi[0] = 0; 904 for (PetscInt i = 0; i < Fm; i++) { 905 PetscInt *first, *last, *lb1, *lb2; 906 // cut the row into: Left, [cstart, cend), Right 907 first = Fj + Fi[i]; 908 last = Fj + Fi[i + 1]; 909 lb1 = std::lower_bound(first, last, cstart); 910 F_NzLeft[i] = lb1 - first; 911 lb2 = std::lower_bound(first, last, cend); 912 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 913 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 914 } 915 for (PetscInt i = 0; i < Fm; i++) { 916 Fdi[i + 1] += Fdi[i]; 917 Foi[i + 1] += Foi[i]; 918 } 919 920 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 921 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 922 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 923 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 924 925 for (PetscInt i = 0; i < Fm; i++) { 926 PetscInt nzLeft = F_NzLeft[i]; 927 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 928 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 929 gid = Fj[Fi[i] + j]; 930 if (j < nzLeft) { // left, in global 931 Foj[Foi[i] + j] = gid; 932 } else if (j < nzLeft + len) { // diag, in local 933 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 934 } else { // right, in global 935 Foj[Foi[i] + j - len] = gid; 936 } 937 } 938 } 939 PetscCall(PetscFree2(jbuf, Fj)); 940 PetscCall(PetscFree(Fi)); 941 942 // Reduce global indices in Foj[] and garray1[] into local ones 943 PetscInt n2, *garray2; 944 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 945 946 // Record the plans built above, for reuse 947 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 948 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 949 Kokkos::deep_copy(irootloc_h, tmp); 950 mm->sf = bcastSF; 951 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 952 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 953 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 954 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 955 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 956 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 957 mm->garray = garray2; 958 mm->n = n2; 959 960 // Output Fd and Fo in KokkosCsrMatrix format 961 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 962 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 963 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 964 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 965 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 966 967 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 968 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 969 970 // Compute kernel launch parameters in merging E or splitting F 971 PetscInt teamSize, vectorLength, rowsPerTeam; 972 973 teamSize = vectorLength = rowsPerTeam = -1; 974 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 975 mm->E_TeamSize = teamSize; 976 mm->E_VectorLength = vectorLength; 977 mm->E_RowsPerTeam = rowsPerTeam; 978 979 teamSize = vectorLength = rowsPerTeam = -1; 980 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, 
vectorLength, rowsPerTeam)); 981 mm->F_TeamSize = teamSize; 982 mm->F_VectorLength = vectorLength; 983 mm->F_RowsPerTeam = rowsPerTeam; 984 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 985 986 // Sync E's value to device 987 akok->a_dual.sync_device(); 988 bkok->a_dual.sync_device(); 989 990 // Handy aliases 991 const auto &Aa = akok->a_dual.view_device(); 992 const auto &Ba = bkok->a_dual.view_device(); 993 const auto &Ai = akok->i_dual.view_device(); 994 const auto &Bi = bkok->i_dual.view_device(); 995 996 // Fetch the plans 997 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 998 PetscSF &bcastSF = mm->sf; 999 MatScalarKokkosView &rootBuf = mm->rootBuf; 1000 MatScalarKokkosView &leafBuf = mm->leafBuf; 1001 PetscIntKokkosView &irootloc = mm->irootloc; 1002 PetscIntKokkosView &rowoffset = mm->rowoffset; 1003 1004 PetscInt teamSize = mm->E_TeamSize; 1005 PetscInt vectorLength = mm->E_VectorLength; 1006 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1007 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1008 1009 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1010 PetscCallCXX(Kokkos::parallel_for( 1011 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1012 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1013 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1014 if (r < irootloc.extent(0)) { 1015 PetscInt i = irootloc(r); // row i of E 1016 PetscInt disp = rowoffset(r); 1017 PetscInt alen = Ai(i + 1) - Ai(i); 1018 PetscInt blen = Bi(i + 1) - Bi(i); 1019 PetscInt nzleft = E_NzLeft(i); 1020 1021 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1022 if (j < nzleft) { // B left 1023 rootBuf(disp + j) = Ba(Bi(i) + j); 1024 } else if (j < nzleft + alen) { // diag A 1025 rootBuf(disp + j) = Aa(Ai(i) + j - 
nzleft); 1026 } else { // B right 1027 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1028 } 1029 }); 1030 } 1031 }); 1032 })); 1033 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1034 PetscFunctionReturn(PETSC_SUCCESS); 1035 } 1036 1037 // To finish MatMPIAIJKokkosBcast. 1038 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1039 { 1040 PetscFunctionBegin; 1041 const auto &Fd = mm->Fd; 1042 const auto &Fo = mm->Fo; 1043 const auto &Fdi = Fd.graph.row_map; 1044 const auto &Foi = Fo.graph.row_map; 1045 auto &Fda = Fd.values; 1046 auto &Foa = Fo.values; 1047 auto Fm = Fd.numRows(); 1048 1049 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1050 PetscSF &bcastSF = mm->sf; 1051 MatScalarKokkosView &rootBuf = mm->rootBuf; 1052 MatScalarKokkosView &leafBuf = mm->leafBuf; 1053 PetscInt teamSize = mm->F_TeamSize; 1054 PetscInt vectorLength = mm->F_VectorLength; 1055 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1056 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1057 1058 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1059 1060 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1061 PetscCallCXX(Kokkos::parallel_for( 1062 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1063 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1064 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1065 if (i < Fm) { 1066 PetscInt nzLeft = F_NzLeft(i); 1067 PetscInt alen = Fdi(i + 1) - Fdi(i); 1068 PetscInt blen = Foi(i + 1) - Foi(i); 1069 PetscInt Fii = Fdi(i) + Foi(i); 1070 1071 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1072 PetscScalar val = leafBuf(Fii + j); 1073 if (j < nzLeft) { // 
left 1074 Foa(Foi(i) + j) = val; 1075 } else if (j < nzLeft + alen) { // diag 1076 Fda(Fdi(i) + j - nzLeft) = val; 1077 } else { // right 1078 Foa(Foi(i) + j - alen) = val; 1079 } 1080 }); 1081 } 1082 }); 1083 })); 1084 PetscFunctionReturn(PETSC_SUCCESS); 1085 } 1086 1087 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1088 { 1089 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1090 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1091 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1092 PetscInt cstart, cend; 1093 MPI_Comm comm; 1094 1095 PetscFunctionBegin; 1096 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1097 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1098 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1099 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1100 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1101 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1102 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1103 1104 // TODO: add command line options to select spgemm algorithms 1105 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1106 1107 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1108 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1109 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1110 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1111 #endif 1112 #endif 1113 1114 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1115 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1116 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1117 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1118 1119 // Aot * (B's diag + B's off-diag) 1120 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1121 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1122 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1123 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1124 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1125 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1126 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1127 1128 PetscCallCXX(sort_crs_matrix(mm->C3)); 1129 PetscCallCXX(sort_crs_matrix(mm->C4)); 1130 #endif 1131 1132 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1133 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1134 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1135 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1136 1137 // Adt * (B's diag + B's off-diag) 1138 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1139 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1140 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1141 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1142 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1143 PetscCallCXX(sort_crs_matrix(mm->C1)); 1144 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1145 #endif 1146 1147 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1148 1149 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1150 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1151 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1152 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1153 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1154 1155 // C = (C1+Fd, C2+Fo) 1156 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1157 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1158 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1159 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1160 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1161 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1162 PetscFunctionReturn(PETSC_SUCCESS); 1163 } 1164 1165 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1166 { 1167 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1168 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1169 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1170 MPI_Comm comm; 1171 1172 PetscFunctionBegin; 1173 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1174 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1175 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1176 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1177 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1178 1179 // Aot * (B's diag + B's off-diag) 1180 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1181 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1182 1183 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1184 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1185 1186 // Adt * (B's diag + B's off-diag) 1187 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1188 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1189 1190 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1191 1192 // C = (C1+Fd, C2+Fo) 1193 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1194 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1195 PetscFunctionReturn(PETSC_SUCCESS); 1196 } 1197 1198 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1199 1200 Input Parameters: 1201 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1202 . A - an MPIAIJKOKKOS matrix 1203 . B - an MPIAIJKOKKOS matrix 1204 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1205 */ 1206 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1207 { 1208 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1209 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1210 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1211 1212 PetscFunctionBegin; 1213 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1214 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1215 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1216 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1217 1218 // TODO: add command line options to select spgemm algorithms 1219 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1220 1221 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1222 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1223 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1224 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1225 #endif 1226 #endif 1227 1228 mm->kh1.create_spgemm_handle(spgemm_alg); 1229 mm->kh2.create_spgemm_handle(spgemm_alg); 1230 mm->kh3.create_spgemm_handle(spgemm_alg); 1231 mm->kh4.create_spgemm_handle(spgemm_alg); 1232 1233 // Bcast B's rows to form F, and overlap the communication 1234 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1235 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1236 1237 // A's diag * (B's diag + B's off-diag) 1238 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1239 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1240 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1241 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1242 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1243 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1244 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1245 PetscCallCXX(sort_crs_matrix(mm->C1)); 1246 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1247 #endif 1248 1249 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1250 1251 // A's off-diag * (F's diag + F's off-diag) 1252 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1253 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1254 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1255 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1256 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1257 PetscCallCXX(sort_crs_matrix(mm->C3)); 1258 PetscCallCXX(sort_crs_matrix(mm->C4)); 1259 #endif 1260 1261 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1262 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1263 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1264 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1265 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1266 1267 // C = (Cd, Co) = (C1+C3, C2+C4) 1268 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1269 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1270 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1271 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1272 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1273 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1274 PetscFunctionReturn(PETSC_SUCCESS); 1275 } 1276 1277 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1278 { 1279 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1280 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1281 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1282 1283 PetscFunctionBegin; 1284 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1285 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1286 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1287 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1288 1289 // Bcast B's rows to form F, and overlap the communication 1290 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1291 1292 // A's diag * (B's diag + B's off-diag) 1293 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1294 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1295 1296 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1297 1298 // A's off-diag * (F's diag + F's off-diag) 1299 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1300 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1301 1302 // C = (Cd, Co) = (C1+C3, C2+C4) 1303 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1304 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1305 PetscFunctionReturn(PETSC_SUCCESS); 1306 } 1307 1308 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1309 { 1310 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1311 Mat_Product *product; 1312 MatProductData_MPIAIJKokkos *pdata; 1313 
MatProductType ptype; 1314 Mat A, B; 1315 1316 PetscFunctionBegin; 1317 MatCheckProduct(C, 1); // make sure C is a product 1318 product = C->product; 1319 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1320 ptype = product->type; 1321 A = product->A; 1322 B = product->B; 1323 1324 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1325 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1326 // we still do numeric. 1327 if (pdata->reusesym) { // numeric reuses results from symbolic 1328 pdata->reusesym = PETSC_FALSE; 1329 PetscFunctionReturn(PETSC_SUCCESS); 1330 } 1331 1332 if (ptype == MATPRODUCT_AB) { 1333 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1334 } else if (ptype == MATPRODUCT_AtB) { 1335 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1336 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1337 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1338 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1339 } 1340 1341 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1342 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1343 PetscFunctionReturn(PETSC_SUCCESS); 1344 } 1345 1346 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1347 { 1348 Mat A, B; 1349 Mat_Product *product; 1350 MatProductType ptype; 1351 MatProductData_MPIAIJKokkos *pdata; 1352 MatMatStruct *mm = NULL; 1353 PetscInt m, n, M, N; 1354 Mat Cd, Co; 1355 MPI_Comm comm; 1356 1357 PetscFunctionBegin; 1358 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1359 MatCheckProduct(C, 1); 1360 product = C->product; 1361 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1362 ptype = product->type; 1363 A = product->A; 1364 B = product->B; 
1365 1366 switch (ptype) { 1367 case MATPRODUCT_AB: 1368 m = A->rmap->n; 1369 n = B->cmap->n; 1370 M = A->rmap->N; 1371 N = B->cmap->N; 1372 break; 1373 case MATPRODUCT_AtB: 1374 m = A->cmap->n; 1375 n = B->cmap->n; 1376 M = A->cmap->N; 1377 N = B->cmap->N; 1378 break; 1379 case MATPRODUCT_PtAP: 1380 m = B->cmap->n; 1381 n = B->cmap->n; 1382 M = B->cmap->N; 1383 N = B->cmap->N; 1384 break; /* BtAB */ 1385 default: 1386 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1387 } 1388 1389 PetscCall(MatSetSizes(C, m, n, M, N)); 1390 PetscCall(PetscLayoutSetUp(C->rmap)); 1391 PetscCall(PetscLayoutSetUp(C->cmap)); 1392 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1393 1394 pdata = new MatProductData_MPIAIJKokkos(); 1395 pdata->reusesym = product->api_user; 1396 1397 if (ptype == MATPRODUCT_AB) { 1398 auto mmAB = new MatMatStruct_AB(); 1399 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1400 mm = pdata->mmAB = mmAB; 1401 } else if (ptype == MATPRODUCT_AtB) { 1402 auto mmAtB = new MatMatStruct_AtB(); 1403 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1404 mm = pdata->mmAtB = mmAtB; 1405 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1406 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1407 1408 auto mmAB = new MatMatStruct_AB(); 1409 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1410 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1411 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1412 pdata->mmAB = mmAB; 1413 1414 m = A->rmap->n; // Z's layout 1415 n = B->cmap->n; 1416 M = A->rmap->N; 1417 N = B->cmap->N; 1418 PetscCall(MatCreate(comm, &Z)); 1419 PetscCall(MatSetSizes(Z, m, n, M, N)); 1420 PetscCall(PetscLayoutSetUp(Z->rmap)); 1421 PetscCall(PetscLayoutSetUp(Z->cmap)); 1422 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1423 
PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1424 1425 auto mmAtB = new MatMatStruct_AtB(); 1426 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1427 1428 pdata->Z = Z; // give ownership to pdata 1429 mm = pdata->mmAtB = mmAtB; 1430 } 1431 1432 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1433 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1434 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1435 1436 C->product->data = pdata; 1437 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1438 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1439 PetscFunctionReturn(PETSC_SUCCESS); 1440 } 1441 1442 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1443 { 1444 Mat_Product *product = mat->product; 1445 PetscBool match = PETSC_FALSE; 1446 PetscBool usecpu = PETSC_FALSE; 1447 1448 PetscFunctionBegin; 1449 MatCheckProduct(mat, 1); 1450 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1451 if (match) { /* we can always fallback to the CPU if requested */ 1452 switch (product->type) { 1453 case MATPRODUCT_AB: 1454 if (product->api_user) { 1455 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1456 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1457 PetscOptionsEnd(); 1458 } else { 1459 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1460 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1461 PetscOptionsEnd(); 1462 } 1463 break; 1464 case MATPRODUCT_AtB: 1465 if 
(product->api_user) { 1466 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1467 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1468 PetscOptionsEnd(); 1469 } else { 1470 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1471 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1472 PetscOptionsEnd(); 1473 } 1474 break; 1475 case MATPRODUCT_PtAP: 1476 if (product->api_user) { 1477 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1478 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1479 PetscOptionsEnd(); 1480 } else { 1481 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1482 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1483 PetscOptionsEnd(); 1484 } 1485 break; 1486 default: 1487 break; 1488 } 1489 match = (PetscBool)!usecpu; 1490 } 1491 if (match) { 1492 switch (product->type) { 1493 case MATPRODUCT_AB: 1494 case MATPRODUCT_AtB: 1495 case MATPRODUCT_PtAP: 1496 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1497 break; 1498 default: 1499 break; 1500 } 1501 } 1502 /* fallback to MPIAIJ ops */ 1503 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1504 PetscFunctionReturn(PETSC_SUCCESS); 1505 } 1506 1507 // Mirror of MatCOOStruct_MPIAIJ on device 1508 struct MatCOOStruct_MPIAIJKokkos { 1509 PetscCount n; 1510 PetscSF sf; 1511 PetscCount Annz, Bnnz; 1512 PetscCount Annz2, Bnnz2; 1513 PetscCountKokkosView Ajmap1, Aperm1; 1514 PetscCountKokkosView Bjmap1, Bperm1; 1515 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1516 
PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1517 PetscCountKokkosView Cperm1; 1518 MatScalarKokkosView sendbuf, recvbuf; 1519 1520 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) 1521 { 1522 auto &exec = PetscGetKokkosExecutionSpace(); 1523 1524 n = coo_h->n; 1525 sf = coo_h->sf; 1526 Annz = coo_h->Annz; 1527 Bnnz = coo_h->Bnnz; 1528 Annz2 = coo_h->Annz2; 1529 Bnnz2 = coo_h->Bnnz2; 1530 Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1)); 1531 Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1)); 1532 Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1)); 1533 Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1)); 1534 Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2)); 1535 Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1)); 1536 Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2)); 1537 Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2)); 1538 Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1)); 1539 Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2)); 1540 Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen)); 1541 sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen)); 1542 recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)); 1543 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1544 } 1545 1546 
~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1547 }; 1548 1549 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1550 { 1551 PetscFunctionBegin; 1552 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1553 PetscFunctionReturn(PETSC_SUCCESS); 1554 } 1555 1556 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1557 { 1558 PetscContainer container_h, container_d; 1559 MatCOOStruct_MPIAIJ *coo_h; 1560 MatCOOStruct_MPIAIJKokkos *coo_d; 1561 1562 PetscFunctionBegin; 1563 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1564 mat->preallocated = PETSC_TRUE; 1565 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1566 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1567 PetscCall(MatZeroEntries(mat)); 1568 1569 // Copy the COO struct to device 1570 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1571 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1572 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1573 1574 // Put the COO struct in a container and then attach that to the matrix 1575 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1576 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1577 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1578 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1579 PetscCall(PetscContainerDestroy(&container_d)); 1580 PetscFunctionReturn(PETSC_SUCCESS); 1581 } 1582 1583 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1584 { 1585 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1586 Mat A = mpiaij->A, B = mpiaij->B; 1587 MatScalarKokkosView Aa, Ba; 1588 MatScalarKokkosView v1; 
1589 PetscMemType memtype; 1590 PetscContainer container; 1591 MatCOOStruct_MPIAIJKokkos *coo; 1592 Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 1593 1594 PetscFunctionBegin; 1595 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1596 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1597 1598 const auto &n = coo->n; 1599 const auto &Annz = coo->Annz; 1600 const auto &Annz2 = coo->Annz2; 1601 const auto &Bnnz = coo->Bnnz; 1602 const auto &Bnnz2 = coo->Bnnz2; 1603 const auto &vsend = coo->sendbuf; 1604 const auto &v2 = coo->recvbuf; 1605 const auto &Ajmap1 = coo->Ajmap1; 1606 const auto &Ajmap2 = coo->Ajmap2; 1607 const auto &Aimap2 = coo->Aimap2; 1608 const auto &Bjmap1 = coo->Bjmap1; 1609 const auto &Bjmap2 = coo->Bjmap2; 1610 const auto &Bimap2 = coo->Bimap2; 1611 const auto &Aperm1 = coo->Aperm1; 1612 const auto &Aperm2 = coo->Aperm2; 1613 const auto &Bperm1 = coo->Bperm1; 1614 const auto &Bperm2 = coo->Bperm2; 1615 const auto &Cperm1 = coo->Cperm1; 1616 1617 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1618 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1619 v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n)); 1620 } else { 1621 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1622 } 1623 1624 if (imode == INSERT_VALUES) { 1625 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1626 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1627 } else { 1628 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1629 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1630 } 1631 1632 PetscCall(PetscLogGpuTimeBegin()); 1633 /* Pack entries to be sent to remote */ 1634 Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) 
= v1(Cperm1(i)); }); 1635 1636 /* Send remote entries to their owner and overlap the communication with local computation */ 1637 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1638 /* Add local entries to A and B in one kernel */ 1639 Kokkos::parallel_for( 1640 Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1641 PetscScalar sum = 0.0; 1642 if (i < Annz) { 1643 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1644 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1645 } else { 1646 i -= Annz; 1647 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1648 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1649 } 1650 }); 1651 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1652 1653 /* Add received remote entries to A and B in one kernel */ 1654 Kokkos::parallel_for( 1655 Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1656 if (i < Annz2) { 1657 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1658 } else { 1659 i -= Annz2; 1660 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1661 } 1662 }); 1663 PetscCall(PetscLogGpuTimeEnd()); 1664 1665 if (imode == INSERT_VALUES) { 1666 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1667 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1668 } else { 1669 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1670 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1671 } 1672 PetscFunctionReturn(PETSC_SUCCESS); 1673 } 1674 1675 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1676 { 1677 PetscFunctionBegin; 1678 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1679 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1680 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1681 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1682 PetscCall(MatDestroy_MPIAIJ(A)); 1683 PetscFunctionReturn(PETSC_SUCCESS); 1684 } 1685 1686 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1687 { 1688 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1689 PetscBool congruent; 1690 1691 PetscFunctionBegin; 1692 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1693 if (congruent) { // square matrix and the diagonals are solely in the diag block 1694 PetscCall(MatShift(mpiaij->A, a)); 1695 } else { // too hard, use the general version 1696 PetscCall(MatShift_Basic(A, a)); 1697 } 1698 PetscFunctionReturn(PETSC_SUCCESS); 1699 } 1700 1701 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1702 { 1703 PetscFunctionBegin; 1704 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1705 B->ops->mult = MatMult_MPIAIJKokkos; 1706 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1707 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1708 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1709 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1710 B->ops->shift = MatShift_MPIAIJKokkos; 1711 1712 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1713 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1714 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1715 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1716 PetscFunctionReturn(PETSC_SUCCESS); 1717 } 1718 1719 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1720 { 1721 Mat B; 1722 Mat_MPIAIJ *a; 1723 1724 PetscFunctionBegin; 1725 if (reuse == MAT_INITIAL_MATRIX) { 1726 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1727 } else if (reuse == MAT_REUSE_MATRIX) { 1728 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1729 } 1730 B = *newmat; 1731 1732 B->boundtocpu = PETSC_FALSE; 1733 PetscCall(PetscFree(B->defaultvectype)); 1734 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1735 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1736 1737 a = static_cast<Mat_MPIAIJ *>(A->data); 1738 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1739 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1740 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1741 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1742 PetscFunctionReturn(PETSC_SUCCESS); 1743 } 1744 1745 /*MC 1746 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1747 1748 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1749 1750 Options Database Key: 1751 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1752 1753 Level: beginner 1754 1755 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1756 M*/ 1757 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1758 { 1759 PetscFunctionBegin; 1760 PetscCall(PetscKokkosInitializeCheck()); 1761 PetscCall(MatCreate_MPIAIJ(A)); 1762 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1763 PetscFunctionReturn(PETSC_SUCCESS); 1764 } 1765 1766 /*@C 1767 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1768 (the default parallel PETSc format). This matrix will ultimately pushed down 1769 to Kokkos for calculations. 1770 1771 Collective 1772 1773 Input Parameters: 1774 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1775 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1776 This value should be the same as the local size used in creating the 1777 y vector for the matrix-vector product y = Ax. 1778 . n - This value should be the same as the local size used in creating the 1779 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1780 calculated if N is given) For square matrices n is almost always `m`. 1781 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1782 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1783 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1784 (same value is used for all local rows) 1785 . d_nnz - array containing the number of nonzeros in the various rows of the 1786 DIAGONAL portion of the local submatrix (possibly different for each row) 1787 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1788 The size of this array is equal to the number of local rows, i.e `m`. 
1789 For matrices you plan to factor you must leave room for the diagonal entry and 1790 put in the entry even if it is zero. 1791 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1792 submatrix (same value is used for all local rows). 1793 - o_nnz - array containing the number of nonzeros in the various rows of the 1794 OFF-DIAGONAL portion of the local submatrix (possibly different for 1795 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1796 structure. The size of this array is equal to the number 1797 of local rows, i.e `m`. 1798 1799 Output Parameter: 1800 . A - the matrix 1801 1802 Level: intermediate 1803 1804 Notes: 1805 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1806 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1807 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1808 1809 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1810 storage. That is, the stored row and column indices can begin at 1811 either one (as in Fortran) or zero. 
1812 1813 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1814 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1815 @*/ 1816 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1817 { 1818 PetscMPIInt size; 1819 1820 PetscFunctionBegin; 1821 PetscCall(MatCreate(comm, A)); 1822 PetscCall(MatSetSizes(*A, m, n, M, N)); 1823 PetscCallMPI(MPI_Comm_size(comm, &size)); 1824 if (size > 1) { 1825 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1826 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1827 } else { 1828 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1829 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1830 } 1831 PetscFunctionReturn(PETSC_SUCCESS); 1832 } 1833