1 #include <petscvec_kokkos.hpp> 2 #include <petscpkg_version.h> 3 #include <petsc/private/sfimpl.h> 4 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 5 #include <../src/mat/impls/aij/mpi/mpiaij.h> 6 #include <KokkosSparse_spadd.hpp> 7 #include <KokkosSparse_spgemm.hpp> 8 9 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 10 { 11 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 12 13 PetscFunctionBegin; 14 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 15 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 16 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 17 */ 18 if (mode == MAT_FINAL_ASSEMBLY) { 19 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 20 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 21 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 22 } 23 24 PetscFunctionReturn(PETSC_SUCCESS); 25 } 26 27 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 28 { 29 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 30 31 PetscFunctionBegin; 32 PetscCall(PetscLayoutSetUp(mat->rmap)); 33 PetscCall(PetscLayoutSetUp(mat->cmap)); 34 #if defined(PETSC_USE_DEBUG) 35 if (d_nnz) { 36 PetscInt i; 37 for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); 38 } 39 if (o_nnz) { 40 PetscInt i; 41 for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]); 42 } 43 #endif 44 #if defined(PETSC_USE_CTABLE) 45 PetscCall(PetscHMapIDestroy(&mpiaij->colmap)); 46 #else 47 PetscCall(PetscFree(mpiaij->colmap)); 48 #endif 49 PetscCall(PetscFree(mpiaij->garray)); 50 
PetscCall(VecDestroy(&mpiaij->lvec)); 51 PetscCall(VecScatterDestroy(&mpiaij->Mvctx)); 52 /* Because the B will have been resized we simply destroy it and create a new one each time */ 53 PetscCall(MatDestroy(&mpiaij->B)); 54 55 if (!mpiaij->A) { 56 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A)); 57 PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n)); 58 } 59 if (!mpiaij->B) { 60 PetscMPIInt size; 61 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size)); 62 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B)); 63 PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0)); 64 } 65 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 66 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 67 PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz)); 68 PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz)); 69 mat->preallocated = PETSC_TRUE; 70 PetscFunctionReturn(PETSC_SUCCESS); 71 } 72 73 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 74 { 75 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 76 PetscInt nt; 77 78 PetscFunctionBegin; 79 PetscCall(VecGetLocalSize(xx, &nt)); 80 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 81 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 82 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 83 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 84 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 85 PetscFunctionReturn(PETSC_SUCCESS); 86 } 87 88 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 89 { 90 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 91 PetscInt nt; 92 93 PetscFunctionBegin; 94 PetscCall(VecGetLocalSize(xx, &nt)); 95 PetscCheck(nt == 
mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 96 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 97 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 98 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 99 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 100 PetscFunctionReturn(PETSC_SUCCESS); 101 } 102 103 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 104 { 105 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 106 PetscInt nt; 107 108 PetscFunctionBegin; 109 PetscCall(VecGetLocalSize(xx, &nt)); 110 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 111 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 112 PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 113 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 114 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 115 PetscFunctionReturn(PETSC_SUCCESS); 116 } 117 118 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 119 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 120 C still uses local column ids. Their corresponding global column ids are returned in glob. 
121 */ 122 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 123 { 124 Mat Ad, Ao; 125 const PetscInt *cmap; 126 127 PetscFunctionBegin; 128 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 129 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 130 if (glob) { 131 PetscInt cst, i, dn, on, *gidx; 132 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 133 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 134 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 135 PetscCall(PetscMalloc1(dn + on, &gidx)); 136 for (i = 0; i < dn; i++) gidx[i] = cst + i; 137 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 138 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 139 } 140 PetscFunctionReturn(PETSC_SUCCESS); 141 } 142 143 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 144 struct MatMatStruct { 145 PetscInt n, *garray; // C's garray and its size. 146 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 147 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 148 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 149 PetscIntKokkosView E_NzLeft; 150 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 151 MatScalarKokkosView rootBuf, leafBuf; 152 KokkosCsrMatrix Fd, Fo; // F in split form 153 154 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 155 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 156 KernelHandle kh3; // compute C3 157 KernelHandle kh4; // compute C4 158 159 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 160 PetscInt E_VectorLength; 161 PetscInt E_RowsPerTeam; 162 PetscInt F_TeamSize; 163 PetscInt F_VectorLength; 164 PetscInt F_RowsPerTeam; 165 166 ~MatMatStruct() 167 { 168 PetscFunctionBegin; 169 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 170 PetscFunctionReturnVoid(); 171 } 172 
}; 173 174 struct MatMatStruct_AB : public MatMatStruct { 175 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 176 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 177 PetscIntKokkosView rowoffset; 178 }; 179 180 struct MatMatStruct_AtB : public MatMatStruct { 181 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 182 MatColIdxKokkosView Fdjperm; 183 MatColIdxKokkosView Fojmap; 184 MatColIdxKokkosView Fojperm; 185 }; 186 187 struct MatProductData_MPIAIJKokkos { 188 MatMatStruct_AB *mmAB = nullptr; 189 MatMatStruct_AtB *mmAtB = nullptr; 190 PetscBool reusesym = PETSC_FALSE; 191 Mat Z = nullptr; // store Z=AB in computing BtAB 192 193 ~MatProductData_MPIAIJKokkos() 194 { 195 delete mmAB; 196 delete mmAtB; 197 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 198 } 199 }; 200 201 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 202 { 203 PetscFunctionBegin; 204 PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 205 PetscFunctionReturn(PETSC_SUCCESS); 206 } 207 208 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 209 It is similar to MatCreateMPIAIJWithSplitArrays. 210 211 Input Parameters: 212 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 213 . A - the diag matrix using local col ids 214 - B - the offdiag matrix using global col ids 215 216 Output Parameter: 217 . 
mat - the updated MATMPIAIJKOKKOS matrix
*/
static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
{
  Mat_MPIAIJ *aij = static_cast<Mat_MPIAIJ *>(mat->data);
  PetscInt    m, n, M, N, Am, An, Bm, Bn;

  PetscFunctionBegin;
  PetscCall(MatGetSize(mat, &M, &N));
  PetscCall(MatGetLocalSize(mat, &m, &n));
  PetscCall(MatGetLocalSize(A, &Am, &An));
  PetscCall(MatGetLocalSize(B, &Bm, &Bn));

  // Sanity checks on the pieces before they are installed
  PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
  PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
  // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
  PetscCheck(!aij->A && !aij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");

  // Install A, B and garray into mat
  aij->A      = A;
  aij->B      = B;
  aij->garray = garray;

  mat->preallocated     = PETSC_TRUE;
  mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */

  PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
  PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
  /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
     also gets mpiaij->B compacted, with its col ids and size reduced
  */
  PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
  // The off-process option is switched back off after assembly; new nonzero locations become errors
  PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
  PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
// split csr matrices.
The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 254 template <class ExecutionSpace> 255 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 256 { 257 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 258 259 PetscFunctionBegin; 260 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 261 262 if (nnz_per_row < 1) nnz_per_row = 1; 263 264 int max_vector_length = teamPolicy.vector_length_max(); 265 266 if (vector_length < 1) { 267 vector_length = 1; 268 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 269 } 270 271 // Determine rows per thread 272 if (rows_per_thread < 1) { 273 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 274 else { 275 if (nnz_per_row < 20 && nnz > 5000000) { 276 rows_per_thread = 256; 277 } else rows_per_thread = 64; 278 } 279 } 280 281 if (team_size < 1) { 282 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 283 team_size = 256 / vector_length; 284 } else { 285 team_size = 1; 286 } 287 } 288 289 rows_per_team = rows_per_thread * team_size; 290 291 if (rows_per_team < 0) { 292 PetscInt nnz_per_team = 4096; 293 PetscInt conc = ExecutionSpace().concurrency(); 294 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 295 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 296 } 297 PetscFunctionReturn(PETSC_SUCCESS); 298 } 299 300 /* 301 Reduce two sets of global indices into local ones 302 303 Input Parameters: 304 + n1 - size of garray1[], the first set 305 . garray1[n1] - a sorted global index array (without duplicates) 306 . 
m - size of indices[], the second set 307 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 308 309 Output Parameters: 310 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 311 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 312 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 313 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 314 315 Example, say 316 n1 = 5 317 garray1[5] = {1, 4, 7, 8, 10} 318 m = 4 319 indices[4] = {2, 4, 8, 9} 320 321 Combining them together, we have 7 global indices in garray2[] 322 n2 = 7 323 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 324 325 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 326 map[5] = {0, 2, 3, 4, 6} 327 328 On output, indices[] is updated with local indices 329 indices[4] = {1, 2, 4, 5} 330 */ 331 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 332 { 333 PetscHMapI g2l = nullptr; 334 PetscHashIter iter; 335 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 336 PetscInt n2, *garray2; 337 338 PetscFunctionBegin; 339 tot = 0; 340 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 341 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 342 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 343 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 344 } 345 346 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 347 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 348 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 349 } 350 351 // Pull out (unique) globals in the hash table and put them in garray2[] 352 n2 = tot; 353 PetscCall(PetscMalloc1(n2, &garray2)); 354 tot = 0; 355 PetscHashIterBegin(g2l, iter); 356 while (!PetscHashIterAtEnd(g2l, iter)) { 357 PetscHashIterGetKey(g2l, iter, key); 358 PetscHashIterNext(g2l, iter); 359 garray2[tot++] = key; 360 } 361 362 // Sort garray2[] and then map them to local indices starting from 0 363 PetscCall(PetscSortInt(n2, garray2)); 364 PetscCall(PetscHMapIClear(g2l)); 365 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 366 367 // Rewrite indices[] with local indices 368 for (PetscInt i = 0; i < m; i++) { 369 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 370 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 371 indices[i] = val; 372 } 373 // Record the map that maps garray1[i] to garray2[map[i]] 374 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 375 PetscCall(PetscHMapIDestroy(&g2l)); 376 *n2_ = n2; 377 *garray2_ = garray2; 378 PetscFunctionReturn(PETSC_SUCCESS); 379 } 380 381 /* 382 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 383 
  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.

  Think of each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
  In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.

  Input Parameters:
+  comm       - MPI communicator of E
.  A          - diag block of E, using local column indices
.  B          - off-diag block of E, using local column indices
.  cstart     - (global) start column of Ed
.  cend       - (global) end column + 1 of Ed.  In other words, E's column ownership is in range of [cstart, cend)
.  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
.  ownerSF    - the SF specifies ownership (root) of rows in E
.  reuse      - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
-  mm         - to stash intermediate data structures for reuse

  Output Parameters:
+  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
-  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
406 407 */ 408 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 409 { 410 PetscFunctionBegin; 411 if (reuse == MAT_INITIAL_MATRIX) { 412 PetscInt Em = A.numRows(), Fm; 413 PetscInt n1 = B.numCols(); 414 415 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 416 417 // Do the analysis on host 418 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 419 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 420 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 421 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 422 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 423 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 424 425 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 426 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 427 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 428 for (PetscInt i = 0; i < Em; i++) { 429 const PetscInt *first, *last, *it; 430 PetscInt count, step; 431 // std::lower_bound(first,last,cstart), but need to use global column indices 432 first = Bj + Bi[i]; 433 last = Bj + Bi[i + 1]; 434 count = last - first; 435 while (count > 0) { 436 it = first; 437 step = count / 2; 438 it += step; 439 if (garray1[*it] < cstart) { // map local to global 440 first = ++it; 441 count -= step + 1; 442 } else count = step; 443 } 444 E_NzLeft[i] = first - (Bj + Bi[i]); 445 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 446 } 447 448 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 449 const PetscMPIInt *iranks, *ranks; 450 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 451 PetscInt niranks, nranks; 452 MPI_Request *reqs; 453 PetscMPIInt tag; 454 PetscSF reduceSF; 455 PetscInt *sdisp, *rdisp; 456 457 PetscCall(PetscCommGetNewTag(comm, &tag)); 458 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 459 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 460 461 // Find out length of each row I will receive. Even for the same row index, when they are from 462 // different senders, they might have different lengths (and sparsity patterns) 463 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 464 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 465 466 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 467 468 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 469 recvRowLen[0] = 0; // since we will make it in CSR format later 470 recvRowLen++; // advance the pointer now 471 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 472 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 473 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 474 475 // Build the real PetscSF for reducing E rows (buffer to buffer) 476 rdisp[0] = 0; 477 for (PetscInt i = 0; i < niranks; i++) { 478 rdisp[i + 1] = rdisp[i]; 479 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 480 } 481 recvRowLen--; // put it back into csr format 482 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 
1] += recvRowLen[i]; 483 484 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 485 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 486 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 487 488 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 489 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 490 PetscSFNode *iremote; 491 492 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 493 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 494 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 495 496 for (PetscInt i = 0; i < nranks; i++) { 497 PetscInt count = 0; 498 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 499 for (PetscInt j = 0; j < count; j++) { 500 iremote[nleaves + j].rank = ranks[i]; 501 iremote[nleaves + j].index = sdisp[i] + j; 502 } 503 nleaves += count; 504 } 505 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 506 507 PetscCall(PetscSFCreate(comm, &reduceSF)); 508 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 509 510 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 511 PetscInt *sendCol, *recvCol; 512 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 513 for (PetscInt k = 0; k < roffset[nranks]; k++) { 514 PetscInt i = rmine[k]; // row to be copied 515 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 516 PetscInt nzLeft = E_NzLeft[i]; 517 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 518 for (PetscInt j = 0; j < alen + blen; j++) { 519 if (j < nzLeft) { 520 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 521 } else if (j < nzLeft + alen) { 522 buf[j] = Aj[Ai[i] 
+ j - nzLeft] + cstart; // diag A, also in global 523 } else { 524 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 525 } 526 } 527 } 528 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 529 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 530 531 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 532 PetscInt *recvRowPerm, *recvColSorted; 533 PetscInt *recvNzPerm, *recvNzPermSorted; 534 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 535 536 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 537 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 538 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 539 540 // i[] array, nz are always easiest to compute 541 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 542 MatRowMapType *Fdi, *Foi; 543 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 544 PetscInt iter; 545 546 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 547 Kokkos::deep_copy(Foi_h, 0); 548 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 549 Foi = Foi_h.data() + 1; 550 iter = 0; 551 while (iter < recvRowCnt) { // iter over received rows 552 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 553 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 554 555 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 556 557 // Copy column indices 
(and their permutation) of these rows into recvColSorted & recvNzPermSorted 558 PetscInt nz = 0; // nz (with dups) in the current row 559 PetscInt *jbuf = recvColSorted + FnzDups; 560 PetscInt *pbuf = recvNzPermSorted + FnzDups; 561 PetscInt *jbuf2 = jbuf; // temp pointers 562 PetscInt *pbuf2 = pbuf; 563 for (PetscInt d = 0; d < dupRows; d++) { 564 PetscInt i = recvRowPerm[iter + d]; 565 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 566 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 567 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 568 jbuf2 += len; 569 pbuf2 += len; 570 nz += len; 571 } 572 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 573 574 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 575 PetscInt cur = 0; 576 while (cur < nz) { 577 PetscInt curColIdx = jbuf[cur]; 578 PetscInt dups = 1; 579 580 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 581 if (curColIdx >= cstart && curColIdx < cend) { 582 Fdi[curRowIdx]++; 583 FdnzDups += dups; 584 } else { 585 Foi[curRowIdx]++; 586 FonzDups += dups; 587 } 588 cur += dups; 589 } 590 591 FnzDups += nz; 592 iter += dupRows; // Move to next unique row 593 } 594 595 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 596 Foi = Foi_h.data(); 597 for (PetscInt i = 0; i < Fm; i++) { 598 Fdi[i + 1] += Fdi[i]; 599 Foi[i + 1] += Foi[i]; 600 } 601 Fdnz = Fdi[Fm]; 602 Fonz = Foi[Fm]; 603 PetscCall(PetscFree2(sendCol, recvCol)); 604 605 // Allocate j, jmap, jperm for Fd and Fo 606 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 607 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 608 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 609 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 610 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 611 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 612 613 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 614 Fdjmap[0] = 0; 615 Fojmap[0] = 0; 616 FnzDups = 0; 617 Fdnz = 0; 618 Fonz = 0; 619 iter = 0; // iter over received rows 620 while (iter < recvRowCnt) { 621 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 622 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 623 PetscInt nz = 0; // nz (with dups) in the current row 624 625 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 626 for (PetscInt d = 0; d < dupRows; d++) { 627 PetscInt i = recvRowPerm[iter + d]; 628 nz += recvRowLen[i + 1] - recvRowLen[i]; 629 } 630 631 PetscInt *jbuf = recvColSorted + FnzDups; 632 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 633 PetscInt cur = 0; 634 while (cur < nz) { 635 PetscInt curColIdx = jbuf[cur]; 636 PetscInt dups = 1; 637 638 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 639 if (curColIdx >= cstart && curColIdx < cend) { 640 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 641 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 642 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 643 FdnzDups += dups; 644 Fdnz++; 645 } else { 646 Foj[Fonz] = curColIdx; // in global 647 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 648 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 649 FonzDups += dups; 650 Fonz++; 651 } 652 cur += dups; 653 FnzDups += dups; 654 } 655 iter += dupRows; // Move to next unique row 656 } 657 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 658 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 659 660 // Combine global column indices in garray1[] and Foj[] 661 PetscInt n2, *garray2; 662 663 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 664 mm->sf = reduceSF; 665 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 666 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 667 mm->garray = garray2; // give ownership, so no free 668 mm->n = n2; 669 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 670 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 671 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 672 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 673 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 674 675 // Output Fd and Fo in KokkosCsrMatrix format 676 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 677 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 678 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 679 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 680 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 681 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 682 683 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 684 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 685 686 // Compute kernel launch parameters in merging E 687 PetscInt teamSize, vectorLength, rowsPerTeam; 688 689 teamSize = vectorLength = rowsPerTeam = -1; 690 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 691 mm->E_TeamSize = teamSize; 692 mm->E_VectorLength = 
vectorLength; 693 mm->E_RowsPerTeam = rowsPerTeam; 694 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 695 696 // Handy aliases 697 auto &Aa = A.values; 698 auto &Ba = B.values; 699 const auto &Ai = A.graph.row_map; 700 const auto &Bi = B.graph.row_map; 701 const auto &E_NzLeft = mm->E_NzLeft; 702 auto &leafBuf = mm->leafBuf; 703 auto &rootBuf = mm->rootBuf; 704 PetscSF reduceSF = mm->sf; 705 PetscInt Em = A.numRows(); 706 PetscInt teamSize = mm->E_TeamSize; 707 PetscInt vectorLength = mm->E_VectorLength; 708 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 709 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 710 711 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 712 PetscCallCXX(Kokkos::parallel_for( 713 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 714 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 715 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 716 if (i < Em) { 717 PetscInt disp = Ai(i) + Bi(i); 718 PetscInt alen = Ai(i + 1) - Ai(i); 719 PetscInt blen = Bi(i + 1) - Bi(i); 720 PetscInt nzleft = E_NzLeft(i); 721 722 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 723 MatScalar &val = leafBuf(disp + j); 724 if (j < nzleft) { // B left 725 val = Ba(Bi(i) + j); 726 } else if (j < nzleft + alen) { // diag A 727 val = Aa(Ai(i) + j - nzleft); 728 } else { // B right 729 val = Ba(Bi(i) + j - alen); 730 } 731 }); 732 } 733 }); 734 })); 735 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 736 PetscFunctionReturn(PETSC_SUCCESS); 737 } 738 739 // To finish MatMPIAIJKokkosReduce. 
740 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 741 { 742 PetscFunctionBegin; 743 auto &leafBuf = mm->leafBuf; 744 auto &rootBuf = mm->rootBuf; 745 auto &Fda = mm->Fd.values; 746 const auto &Fdjmap = mm->Fdjmap; 747 const auto &Fdjperm = mm->Fdjperm; 748 auto Fdnz = mm->Fd.nnz(); 749 auto &Foa = mm->Fo.values; 750 const auto &Fojmap = mm->Fojmap; 751 const auto &Fojperm = mm->Fojperm; 752 auto Fonz = mm->Fo.nnz(); 753 PetscSF reduceSF = mm->sf; 754 755 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 756 757 // Reduce data in rootBuf to Fd and Fo 758 PetscCallCXX(Kokkos::parallel_for( 759 Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) { 760 PetscScalar sum = 0.0; 761 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 762 Fda(i) = sum; 763 })); 764 765 PetscCallCXX(Kokkos::parallel_for( 766 Fonz, KOKKOS_LAMBDA(const MatRowMapType i) { 767 PetscScalar sum = 0.0; 768 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 769 Foa(i) = sum; 770 })); 771 PetscFunctionReturn(PETSC_SUCCESS); 772 } 773 774 /* 775 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 776 777 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 778 device and involves various index mapping. 779 780 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 781 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 782 to j-th row of F. 
  ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
  F has the same column layout as E.

  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
  Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
  Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
  column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
  column indices in Fo and update Fo with local indices.

  Input Parameters:
+ E       - the MPIAIJKOKKOS matrix
. ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
. reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm      - to stash matproduct intermediate data structures

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
- mm      - contains various info, such as garray2[], Fd, Fo, etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
  The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlap opportunities.
804 */ 805 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 806 { 807 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 808 Mat A = empi->A, B = empi->B; // diag and off-diag 809 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 810 PetscInt Em = E->rmap->n; // #local rows 811 MPI_Comm comm; 812 813 PetscFunctionBegin; 814 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 815 if (reuse == MAT_INITIAL_MATRIX) { 816 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 817 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 818 const PetscInt *garray1 = empi->garray; // its size is n1 819 PetscInt cstart, cend; 820 PetscSF bcastSF; 821 822 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 823 824 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 825 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 826 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 827 for (PetscInt i = 0; i < Em; i++) { 828 const PetscInt *first, *last, *it; 829 PetscInt count, step; 830 // std::lower_bound(first,last,cstart), but need to use global column indices 831 first = Bj + Bi[i]; 832 last = Bj + Bi[i + 1]; 833 count = last - first; 834 while (count > 0) { 835 it = first; 836 step = count / 2; 837 it += step; 838 if (empi->garray[*it] < cstart) { // map local to global 839 first = ++it; 840 count -= step + 1; 841 } else count = step; 842 } 843 E_NzLeft[i] = first - (Bj + Bi[i]); 844 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 845 } 846 847 // Compute row pointer Fi of F 848 PetscInt *Fi, Fm, Fnz; 849 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 850 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 851 Fi[0] = 0; 852 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 853 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 854 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 855 Fnz = Fi[Fm]; 856 857 // Build the real PetscSF for bcasting E rows (buffer to buffer) 858 const PetscMPIInt *iranks, *ranks; 859 const PetscInt *ioffset, *irootloc, *roffset; 860 PetscInt niranks, nranks, *sdisp, *rdisp; 861 MPI_Request *reqs; 862 PetscMPIInt tag; 863 864 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 865 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 866 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 867 868 sdisp[0] = 0; // send displacement 869 for (PetscInt i = 0; i < niranks; i++) { 870 sdisp[i + 1] = sdisp[i]; 871 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 872 PetscInt r = irootloc[j]; // row to be sent 873 sdisp[i + 1] += E_RowLen[r]; 874 } 875 } 876 877 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 878 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 879 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 880 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 881 882 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 883 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 884 PetscSFNode *iremote; // give ownership to bcastSF 885 PetscCall(PetscMalloc1(nleaves, &iremote)); 886 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 887 PetscInt k = 0; 888 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 889 iremote[j].rank = ranks[i]; 890 iremote[j].index = rdisp[i] + k; // their root location 891 k++; 892 } 893 } 894 PetscCall(PetscSFCreate(comm, &bcastSF)); 895 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 896 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 897 898 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 899 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 900 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 901 rowoffset[0] = 0; 902 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 903 904 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 905 PetscInt *jbuf, *Fj; 906 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 907 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 908 PetscInt i = irootloc[k]; // row to be copied 909 PetscInt *buf = &jbuf[rowoffset[k]]; 910 PetscInt nzLeft = E_NzLeft[i]; 911 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 912 for (PetscInt j = 0; j < alen + blen; j++) { 913 if (j < nzLeft) { 914 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 915 } else if (j < nzLeft + alen) { 916 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 917 } else { 918 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 919 } 920 } 921 } 922 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 923 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 924 925 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 926 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 927 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 928 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 929 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 930 931 Fdi[0] = Foi[0] = 0; 932 for (PetscInt i = 0; i < Fm; i++) { 933 PetscInt *first, *last, *lb1, *lb2; 934 // cut the row into: Left, [cstart, cend), Right 935 first = Fj + Fi[i]; 936 last = Fj + Fi[i + 1]; 937 lb1 = std::lower_bound(first, last, cstart); 938 F_NzLeft[i] = lb1 - first; 939 lb2 = std::lower_bound(first, last, cend); 940 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 941 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 942 } 943 for (PetscInt i = 0; i < Fm; i++) { 944 Fdi[i + 1] += Fdi[i]; 945 Foi[i + 1] += Foi[i]; 946 } 947 948 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 949 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 950 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 951 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 952 953 for (PetscInt i = 0; i < Fm; i++) { 954 PetscInt nzLeft = F_NzLeft[i]; 955 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 956 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 957 gid = Fj[Fi[i] + j]; 958 if (j < nzLeft) { // left, in global 959 Foj[Foi[i] + j] = gid; 960 } else if (j < nzLeft + len) { // diag, in local 961 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 962 } else { // right, in global 963 Foj[Foi[i] + j - len] = gid; 964 } 965 } 966 } 967 PetscCall(PetscFree2(jbuf, Fj)); 968 PetscCall(PetscFree(Fi)); 969 970 // Reduce global indices in Foj[] and garray1[] into local ones 971 PetscInt n2, *garray2; 972 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 973 974 // Record the plans built above, for reuse 975 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 976 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 977 Kokkos::deep_copy(irootloc_h, tmp); 978 mm->sf = bcastSF; 979 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 980 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 981 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 982 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 983 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 984 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 985 mm->garray = garray2; 986 mm->n = n2; 987 988 // Output Fd and Fo in KokkosCsrMatrix format 989 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 990 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 991 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 992 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 993 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 994 995 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 996 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 997 998 // Compute kernel launch parameters in merging E or splitting F 999 PetscInt teamSize, vectorLength, rowsPerTeam; 1000 1001 teamSize = vectorLength = rowsPerTeam = -1; 1002 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 1003 mm->E_TeamSize = teamSize; 1004 mm->E_VectorLength = vectorLength; 1005 mm->E_RowsPerTeam = rowsPerTeam; 1006 1007 teamSize = vectorLength = rowsPerTeam = -1; 1008 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, 
teamSize, vectorLength, rowsPerTeam)); 1009 mm->F_TeamSize = teamSize; 1010 mm->F_VectorLength = vectorLength; 1011 mm->F_RowsPerTeam = rowsPerTeam; 1012 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 1013 1014 // Sync E's value to device 1015 akok->a_dual.sync_device(); 1016 bkok->a_dual.sync_device(); 1017 1018 // Handy aliases 1019 const auto &Aa = akok->a_dual.view_device(); 1020 const auto &Ba = bkok->a_dual.view_device(); 1021 const auto &Ai = akok->i_dual.view_device(); 1022 const auto &Bi = bkok->i_dual.view_device(); 1023 1024 // Fetch the plans 1025 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1026 PetscSF &bcastSF = mm->sf; 1027 MatScalarKokkosView &rootBuf = mm->rootBuf; 1028 MatScalarKokkosView &leafBuf = mm->leafBuf; 1029 PetscIntKokkosView &irootloc = mm->irootloc; 1030 PetscIntKokkosView &rowoffset = mm->rowoffset; 1031 1032 PetscInt teamSize = mm->E_TeamSize; 1033 PetscInt vectorLength = mm->E_VectorLength; 1034 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1035 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1036 1037 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1038 PetscCallCXX(Kokkos::parallel_for( 1039 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1040 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1041 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1042 if (r < irootloc.extent(0)) { 1043 PetscInt i = irootloc(r); // row i of E 1044 PetscInt disp = rowoffset(r); 1045 PetscInt alen = Ai(i + 1) - Ai(i); 1046 PetscInt blen = Bi(i + 1) - Bi(i); 1047 PetscInt nzleft = E_NzLeft(i); 1048 1049 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1050 if (j < nzleft) { // B left 1051 rootBuf(disp + j) = Ba(Bi(i) + j); 1052 } else if (j < nzleft + alen) { // diag A 1053 rootBuf(disp + j) = Aa(Ai(i) + j - 
nzleft); 1054 } else { // B right 1055 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1056 } 1057 }); 1058 } 1059 }); 1060 })); 1061 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1062 PetscFunctionReturn(PETSC_SUCCESS); 1063 } 1064 1065 // To finish MatMPIAIJKokkosBcast. 1066 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1067 { 1068 PetscFunctionBegin; 1069 const auto &Fd = mm->Fd; 1070 const auto &Fo = mm->Fo; 1071 const auto &Fdi = Fd.graph.row_map; 1072 const auto &Foi = Fo.graph.row_map; 1073 auto &Fda = Fd.values; 1074 auto &Foa = Fo.values; 1075 auto Fm = Fd.numRows(); 1076 1077 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1078 PetscSF &bcastSF = mm->sf; 1079 MatScalarKokkosView &rootBuf = mm->rootBuf; 1080 MatScalarKokkosView &leafBuf = mm->leafBuf; 1081 PetscInt teamSize = mm->F_TeamSize; 1082 PetscInt vectorLength = mm->F_VectorLength; 1083 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1084 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1085 1086 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1087 1088 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1089 PetscCallCXX(Kokkos::parallel_for( 1090 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1091 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1092 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1093 if (i < Fm) { 1094 PetscInt nzLeft = F_NzLeft(i); 1095 PetscInt alen = Fdi(i + 1) - Fdi(i); 1096 PetscInt blen = Foi(i + 1) - Foi(i); 1097 PetscInt Fii = Fdi(i) + Foi(i); 1098 1099 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1100 PetscScalar val = leafBuf(Fii + j); 1101 if (j < nzLeft) { // left 1102 Foa(Foi(i) + j) = val; 
} else if (j < nzLeft + alen) { // diag
              Fda(Fdi(i) + j - nzLeft) = val;
            } else { // right
              Foa(Foi(i) + j - alen) = val;
            }
          });
        }
      });
    }));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatProductSymbolic_MPIAIJKokkos_AtB - AtB flavor of the symbolic product; results are stashed in mm

  With Adt/Aot the transposes of A's diag/off-diag blocks and Bd/Bo the diag/off-diag blocks of B, it forms
    C1 = Adt*Bd, C2_mid = Adt*Bo, C3 = Aot*Bd, C4 = Aot*Bo.
  (C3, C4) is reduced across ranks into (mm->Fd, mm->Fo) with MatMPIAIJKokkosReduceBegin/End(); the C1/C2_mid
  products are placed between the Begin and End calls. C2 is C2_mid with its column indices translated through
  the map filled in by the reduce. Finally mm->Cd = C1 + Fd and mm->Co = C2 + Fo.

  A numeric spgemm is executed right after each symbolic one (see the KK note in the body), and a numeric spadd
  after each symbolic spadd.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
  PetscInt        cstart, cend;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
  // NOTE(review): Ad and Ao are fetched but not used below in this routine
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One kernel handle per product; they are kept in mm and used again by the numeric phase
  PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));

  // Aot * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
  PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
  PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Adt * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  PetscCallCXX(Kokkos::parallel_for(
    oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));

  // C = (C1+Fd, C2+Fo)
  PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
  PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Numeric phase of AtB: recomputes the values of mm->Cd and mm->Co with the handles and intermediate
// matrices that the symbolic phase stored in mm.
static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Bd, Bo;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // Aot * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));

  // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  // With MAT_REUSE_MATRIX the layout arguments are passed as 0/NULL
  PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));

  // Adt * (B's diag + B's off-diag)
  // (placed between ReduceBegin and ReduceEnd, i.e. while the reduce is pending)
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));

  PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));

  // C = (C1+Fd, C2+Fo)
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos

  Input Parameters:
+ product - Mat_Product which carried out the computation. Passed in to access info about this mat product.
. A - an MPIAIJKOKKOS matrix
. B - an MPIAIJKOKKOS matrix
- mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.
1233 */ 1234 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1235 { 1236 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1237 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1238 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1239 1240 PetscFunctionBegin; 1241 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1242 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1243 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1244 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1245 1246 // TODO: add command line options to select spgemm algorithms 1247 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1248 1249 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1250 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1251 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1252 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1253 #endif 1254 #endif 1255 1256 mm->kh1.create_spgemm_handle(spgemm_alg); 1257 mm->kh2.create_spgemm_handle(spgemm_alg); 1258 mm->kh3.create_spgemm_handle(spgemm_alg); 1259 mm->kh4.create_spgemm_handle(spgemm_alg); 1260 1261 // Bcast B's rows to form F, and overlap the communication 1262 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1263 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1264 1265 // A's diag * (B's diag + B's off-diag) 1266 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1267 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1268 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1269 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1270 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1271 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1272 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1273 PetscCallCXX(sort_crs_matrix(mm->C1)); 1274 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1275 #endif 1276 1277 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1278 1279 // A's off-diag * (F's diag + F's off-diag) 1280 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1281 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1282 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1283 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1284 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1285 PetscCallCXX(sort_crs_matrix(mm->C3)); 1286 PetscCallCXX(sort_crs_matrix(mm->C4)); 1287 #endif 1288 1289 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1290 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1291 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1292 PetscCallCXX(Kokkos::parallel_for( 1293 oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1294 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1295 1296 // C = (Cd, Co) = (C1+C3, C2+C4) 1297 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1298 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1299 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1300 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1301 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 
1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1302 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1303 PetscFunctionReturn(PETSC_SUCCESS); 1304 } 1305 1306 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1307 { 1308 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1309 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1310 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1311 1312 PetscFunctionBegin; 1313 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1314 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1315 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1316 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1317 1318 // Bcast B's rows to form F, and overlap the communication 1319 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1320 1321 // A's diag * (B's diag + B's off-diag) 1322 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1323 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1324 1325 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1326 1327 // A's off-diag * (F's diag + F's off-diag) 1328 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1329 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1330 1331 // C = (Cd, Co) = (C1+C3, C2+C4) 1332 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1333 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1334 PetscFunctionReturn(PETSC_SUCCESS); 1335 } 1336 1337 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1338 { 1339 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1340 Mat_Product *product; 1341 MatProductData_MPIAIJKokkos *pdata; 1342 MatProductType ptype; 1343 Mat A, B; 1344 1345 PetscFunctionBegin; 
MatCheckProduct(C, 1); // make sure C is a product
  product = C->product;
  pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
  ptype   = product->type;
  A       = product->A;
  B       = product->B;

  // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
  // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
  // we still do numeric.
  if (pdata->reusesym) { // numeric reuses results from symbolic
    pdata->reusesym = PETSC_FALSE;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  // Dispatch on the product type; any other type falls through without doing numeric work
  if (ptype == MATPRODUCT_AB) {
    PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
  } else if (ptype == MATPRODUCT_AtB) {
    PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
  } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
    PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
    PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
  }

  PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
  PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatProductSymbolic_MPIAIJKokkos - symbolic phase entry point for a product matrix C

  Handles product types AB, AtB and PtAP (the latter computed as Z = AB followed by C = BtZ); any other
  type raises PETSC_ERR_PLIB. The routine sets C's sizes and type (same type name as A), runs the matching
  *_AB / *_AtB symbolic helper(s), wraps the resulting Cd/Co KokkosCsrMatrix objects into C, and installs
  the product data, its destroy routine, and the numeric callback on C.

  It errors out if C->product->data is already set.
*/
PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *pdata;
  MatMatStruct                *mm = NULL;
  PetscInt                     m, n, M, N;
  Mat                          Cd, Co;
  MPI_Comm                     comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
  MatCheckProduct(C, 1);
  product = C->product;
  PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  // Local (m, n) and global (M, N) sizes of C for each supported product type
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  pdata           = new MatProductData_MPIAIJKokkos();
  pdata->reusesym = product->api_user; // when set, the first numeric call is skipped (see MatProductNumeric_MPIAIJKokkos)

  if (ptype == MATPRODUCT_AB) {
    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
    mm = pdata->mmAB = mmAB;
  } else if (ptype == MATPRODUCT_AtB) {
    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
    mm = pdata->mmAtB = mmAtB;
  } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
    Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z

    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
    pdata->mmAB = mmAB;

    m = A->rmap->n; // Z's layout
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    PetscCall(MatCreate(comm, &Z));
    PetscCall(MatSetSizes(Z, m, n, M, N));
    PetscCall(PetscLayoutSetUp(Z->rmap));
    PetscCall(PetscLayoutSetUp(Z->cmap));
    PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
    PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));

    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

    pdata->Z = Z; // give ownership to pdata
    mm = pdata->mmAtB = mmAtB;
  }

  // Wrap the final diag/off-diag results held in mm as SEQAIJKOKKOS matrices and install them in C
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
  PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));

  C->product->data       = pdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}

PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
{
  Mat_Product *product = mat->product;
  PetscBool    match   = PETSC_FALSE;
  PetscBool    usecpu  = PETSC_FALSE;

  PetscFunctionBegin;
  MatCheckProduct(mat, 1);
  if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
  if (match) { /* we can always fallback to the CPU if requested */
    switch (product->type) {
    case MATPRODUCT_AB:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
        PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    case MATPRODUCT_AtB:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat),
((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1496 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1497 PetscOptionsEnd(); 1498 } else { 1499 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1500 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1501 PetscOptionsEnd(); 1502 } 1503 break; 1504 case MATPRODUCT_PtAP: 1505 if (product->api_user) { 1506 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1507 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1508 PetscOptionsEnd(); 1509 } else { 1510 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1511 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1512 PetscOptionsEnd(); 1513 } 1514 break; 1515 default: 1516 break; 1517 } 1518 match = (PetscBool)!usecpu; 1519 } 1520 if (match) { 1521 switch (product->type) { 1522 case MATPRODUCT_AB: 1523 case MATPRODUCT_AtB: 1524 case MATPRODUCT_PtAP: 1525 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1526 break; 1527 default: 1528 break; 1529 } 1530 } 1531 /* fallback to MPIAIJ ops */ 1532 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1533 PetscFunctionReturn(PETSC_SUCCESS); 1534 } 1535 1536 // Mirror of MatCOOStruct_MPIAIJ on device 1537 struct MatCOOStruct_MPIAIJKokkos { 1538 PetscCount n; 1539 PetscSF sf; 1540 PetscCount Annz, Bnnz; 1541 PetscCount Annz2, Bnnz2; 1542 PetscCountKokkosView Ajmap1, Aperm1; 1543 PetscCountKokkosView Bjmap1, Bperm1; 1544 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1545 PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1546 PetscCountKokkosView Cperm1; 1547 
MatScalarKokkosView sendbuf, recvbuf; 1548 1549 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) : 1550 n(coo_h->n), 1551 sf(coo_h->sf), 1552 Annz(coo_h->Annz), 1553 Bnnz(coo_h->Bnnz), 1554 Annz2(coo_h->Annz2), 1555 Bnnz2(coo_h->Bnnz2), 1556 Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))), 1557 Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))), 1558 Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))), 1559 Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))), 1560 Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))), 1561 Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))), 1562 Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))), 1563 Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))), 1564 Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))), 1565 Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))), 1566 Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))), 1567 sendbuf(Kokkos::create_mirror_view(DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))), 1568 recvbuf(Kokkos::create_mirror_view(DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen))) 1569 { 1570 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1571 } 1572 1573 
~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1574 }; 1575 1576 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1577 { 1578 PetscFunctionBegin; 1579 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1580 PetscFunctionReturn(PETSC_SUCCESS); 1581 } 1582 1583 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1584 { 1585 PetscContainer container_h, container_d; 1586 MatCOOStruct_MPIAIJ *coo_h; 1587 MatCOOStruct_MPIAIJKokkos *coo_d; 1588 1589 PetscFunctionBegin; 1590 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1591 mat->preallocated = PETSC_TRUE; 1592 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1593 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1594 PetscCall(MatZeroEntries(mat)); 1595 1596 // Copy the COO struct to device 1597 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1598 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1599 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1600 1601 // Put the COO struct in a container and then attach that to the matrix 1602 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1603 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1604 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1605 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1606 PetscCall(PetscContainerDestroy(&container_d)); 1607 PetscFunctionReturn(PETSC_SUCCESS); 1608 } 1609 1610 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1611 { 1612 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1613 Mat A = mpiaij->A, B = mpiaij->B; 1614 MatScalarKokkosView Aa, Ba; 1615 MatScalarKokkosView v1; 
1616 PetscMemType memtype; 1617 PetscContainer container; 1618 MatCOOStruct_MPIAIJKokkos *coo; 1619 1620 PetscFunctionBegin; 1621 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1622 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1623 1624 const auto &n = coo->n; 1625 const auto &Annz = coo->Annz; 1626 const auto &Annz2 = coo->Annz2; 1627 const auto &Bnnz = coo->Bnnz; 1628 const auto &Bnnz2 = coo->Bnnz2; 1629 const auto &vsend = coo->sendbuf; 1630 const auto &v2 = coo->recvbuf; 1631 const auto &Ajmap1 = coo->Ajmap1; 1632 const auto &Ajmap2 = coo->Ajmap2; 1633 const auto &Aimap2 = coo->Aimap2; 1634 const auto &Bjmap1 = coo->Bjmap1; 1635 const auto &Bjmap2 = coo->Bjmap2; 1636 const auto &Bimap2 = coo->Bimap2; 1637 const auto &Aperm1 = coo->Aperm1; 1638 const auto &Aperm2 = coo->Aperm2; 1639 const auto &Bperm1 = coo->Bperm1; 1640 const auto &Bperm2 = coo->Bperm2; 1641 const auto &Cperm1 = coo->Cperm1; 1642 1643 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1644 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1645 v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n)); 1646 } else { 1647 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1648 } 1649 1650 if (imode == INSERT_VALUES) { 1651 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1652 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1653 } else { 1654 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1655 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1656 } 1657 1658 /* Pack entries to be sent to remote */ 1659 Kokkos::parallel_for( 1660 vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 1661 1662 /* Send remote entries to their owner and overlap the communication with local computation */ 
1663 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1664 /* Add local entries to A and B in one kernel */ 1665 Kokkos::parallel_for( 1666 Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) { 1667 PetscScalar sum = 0.0; 1668 if (i < Annz) { 1669 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1670 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1671 } else { 1672 i -= Annz; 1673 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1674 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1675 } 1676 }); 1677 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1678 1679 /* Add received remote entries to A and B in one kernel */ 1680 Kokkos::parallel_for( 1681 Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) { 1682 if (i < Annz2) { 1683 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1684 } else { 1685 i -= Annz2; 1686 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1687 } 1688 }); 1689 1690 if (imode == INSERT_VALUES) { 1691 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1692 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1693 } else { 1694 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1695 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1696 } 1697 PetscFunctionReturn(PETSC_SUCCESS); 1698 } 1699 1700 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1701 { 1702 PetscFunctionBegin; 1703 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1704 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1705 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1706 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1707 PetscCall(MatDestroy_MPIAIJ(A)); 1708 PetscFunctionReturn(PETSC_SUCCESS); 1709 } 1710 1711 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1712 { 1713 PetscFunctionBegin; 1714 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1715 B->ops->mult = MatMult_MPIAIJKokkos; 1716 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1717 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1718 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1719 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1720 1721 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1722 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1723 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1724 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1725 PetscFunctionReturn(PETSC_SUCCESS); 1726 } 1727 1728 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1729 { 1730 Mat B; 1731 Mat_MPIAIJ *a; 1732 1733 PetscFunctionBegin; 1734 if (reuse == MAT_INITIAL_MATRIX) { 1735 
PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1736 } else if (reuse == MAT_REUSE_MATRIX) { 1737 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1738 } 1739 B = *newmat; 1740 1741 B->boundtocpu = PETSC_FALSE; 1742 PetscCall(PetscFree(B->defaultvectype)); 1743 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1744 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1745 1746 a = static_cast<Mat_MPIAIJ *>(A->data); 1747 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1748 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1749 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1750 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1751 PetscFunctionReturn(PETSC_SUCCESS); 1752 } 1753 1754 /*MC 1755 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1756 1757 A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types 1758 1759 Options Database Key: 1760 . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1761 1762 Level: beginner 1763 1764 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1765 M*/ 1766 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1767 { 1768 PetscFunctionBegin; 1769 PetscCall(PetscKokkosInitializeCheck()); 1770 PetscCall(MatCreate_MPIAIJ(A)); 1771 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1772 PetscFunctionReturn(PETSC_SUCCESS); 1773 } 1774 1775 /*@C 1776 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1777 (the default parallel PETSc format). This matrix will ultimately pushed down 1778 to Kokkos for calculations. 1779 1780 Collective 1781 1782 Input Parameters: 1783 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1784 . 
m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1785 This value should be the same as the local size used in creating the 1786 y vector for the matrix-vector product y = Ax. 1787 . n - This value should be the same as the local size used in creating the 1788 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1789 calculated if N is given) For square matrices n is almost always `m`. 1790 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1791 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1792 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1793 (same value is used for all local rows) 1794 . d_nnz - array containing the number of nonzeros in the various rows of the 1795 DIAGONAL portion of the local submatrix (possibly different for each row) 1796 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1797 The size of this array is equal to the number of local rows, i.e `m`. 1798 For matrices you plan to factor you must leave room for the diagonal entry and 1799 put in the entry even if it is zero. 1800 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1801 submatrix (same value is used for all local rows). 1802 - o_nnz - array containing the number of nonzeros in the various rows of the 1803 OFF-DIAGONAL portion of the local submatrix (possibly different for 1804 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1805 structure. The size of this array is equal to the number 1806 of local rows, i.e `m`. 1807 1808 Output Parameter: 1809 . A - the matrix 1810 1811 Level: intermediate 1812 1813 Notes: 1814 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1815 MatXXXXSetPreallocation() paradigm instead of this routine directly. 
1816 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1817 1818 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1819 storage. That is, the stored row and column indices can begin at 1820 either one (as in Fortran) or zero. 1821 1822 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1823 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1824 @*/ 1825 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1826 { 1827 PetscMPIInt size; 1828 1829 PetscFunctionBegin; 1830 PetscCall(MatCreate(comm, A)); 1831 PetscCall(MatSetSizes(*A, m, n, M, N)); 1832 PetscCallMPI(MPI_Comm_size(comm, &size)); 1833 if (size > 1) { 1834 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1835 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1836 } else { 1837 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1838 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1839 } 1840 PetscFunctionReturn(PETSC_SUCCESS); 1841 } 1842