1 #include <petscvec_kokkos.hpp> 2 #include <petscpkg_version.h> 3 #include <petscsf.h> 4 #include <petsc/private/sfimpl.h> 5 #include <../src/mat/impls/aij/mpi/mpiaij.h> 6 #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp> 7 #include <KokkosSparse_spadd.hpp> 8 #include <KokkosSparse_spgemm.hpp> 9 10 PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 11 { 12 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 13 14 PetscFunctionBegin; 15 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 16 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 17 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 18 */ 19 if (mode == MAT_FINAL_ASSEMBLY) { 20 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 21 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 22 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 23 } 24 25 PetscFunctionReturn(PETSC_SUCCESS); 26 } 27 28 PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 29 { 30 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 31 32 PetscFunctionBegin; 33 PetscCall(PetscLayoutSetUp(mat->rmap)); 34 PetscCall(PetscLayoutSetUp(mat->cmap)); 35 #if defined(PETSC_USE_DEBUG) 36 if (d_nnz) { 37 PetscInt i; 38 for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); 39 } 40 if (o_nnz) { 41 PetscInt i; 42 for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]); 43 } 44 #endif 45 #if defined(PETSC_USE_CTABLE) 46 PetscCall(PetscHMapIDestroy(&mpiaij->colmap)); 47 #else 48 PetscCall(PetscFree(mpiaij->colmap)); 49 #endif 50 PetscCall(PetscFree(mpiaij->garray)); 51 
PetscCall(VecDestroy(&mpiaij->lvec)); 52 PetscCall(VecScatterDestroy(&mpiaij->Mvctx)); 53 /* Because the B will have been resized we simply destroy it and create a new one each time */ 54 PetscCall(MatDestroy(&mpiaij->B)); 55 56 if (!mpiaij->A) { 57 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A)); 58 PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n)); 59 } 60 if (!mpiaij->B) { 61 PetscMPIInt size; 62 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size)); 63 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B)); 64 PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0)); 65 } 66 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 67 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 68 PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz)); 69 PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz)); 70 mat->preallocated = PETSC_TRUE; 71 PetscFunctionReturn(PETSC_SUCCESS); 72 } 73 74 PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 75 { 76 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 77 PetscInt nt; 78 79 PetscFunctionBegin; 80 PetscCall(VecGetLocalSize(xx, &nt)); 81 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 82 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 83 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 84 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 85 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 86 PetscFunctionReturn(PETSC_SUCCESS); 87 } 88 89 PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 90 { 91 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 92 PetscInt nt; 93 94 PetscFunctionBegin; 95 PetscCall(VecGetLocalSize(xx, &nt)); 96 PetscCheck(nt == 
mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 97 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 98 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 99 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 100 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 101 PetscFunctionReturn(PETSC_SUCCESS); 102 } 103 104 PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 105 { 106 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 107 PetscInt nt; 108 109 PetscFunctionBegin; 110 PetscCall(VecGetLocalSize(xx, &nt)); 111 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 112 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 113 PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 114 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 115 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 116 PetscFunctionReturn(PETSC_SUCCESS); 117 } 118 119 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 120 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 121 C still uses local column ids. Their corresponding global column ids are returned in glob. 
122 */ 123 PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 124 { 125 Mat Ad, Ao; 126 const PetscInt *cmap; 127 128 PetscFunctionBegin; 129 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 130 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 131 if (glob) { 132 PetscInt cst, i, dn, on, *gidx; 133 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 134 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 135 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 136 PetscCall(PetscMalloc1(dn + on, &gidx)); 137 for (i = 0; i < dn; i++) gidx[i] = cst + i; 138 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 139 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 140 } 141 PetscFunctionReturn(PETSC_SUCCESS); 142 } 143 144 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 145 struct MatMatStruct { 146 PetscInt n, *garray; // C's garray and its size. 147 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 148 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 149 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 150 PetscIntKokkosView E_NzLeft; 151 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 152 MatScalarKokkosView rootBuf, leafBuf; 153 KokkosCsrMatrix Fd, Fo; // F in split form 154 155 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 156 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 157 KernelHandle kh3; // compute C3 158 KernelHandle kh4; // compute C4 159 160 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 161 PetscInt E_VectorLength; 162 PetscInt E_RowsPerTeam; 163 PetscInt F_TeamSize; 164 PetscInt F_VectorLength; 165 PetscInt F_RowsPerTeam; 166 167 ~MatMatStruct() 168 { 169 PetscFunctionBegin; 170 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 171 PetscFunctionReturnVoid(); 172 } 173 
}; 174 175 struct MatMatStruct_AB : public MatMatStruct { 176 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 177 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 178 PetscIntKokkosView rowoffset; 179 }; 180 181 struct MatMatStruct_AtB : public MatMatStruct { 182 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 183 MatColIdxKokkosView Fdjperm; 184 MatColIdxKokkosView Fojmap; 185 MatColIdxKokkosView Fojperm; 186 }; 187 188 struct MatProductData_MPIAIJKokkos { 189 MatMatStruct_AB *mmAB = nullptr; 190 MatMatStruct_AtB *mmAtB = nullptr; 191 PetscBool reusesym = PETSC_FALSE; 192 Mat Z = nullptr; // store Z=AB in computing BtAB 193 194 ~MatProductData_MPIAIJKokkos() 195 { 196 delete mmAB; 197 delete mmAtB; 198 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 199 } 200 }; 201 202 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 203 { 204 PetscFunctionBegin; 205 PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 206 PetscFunctionReturn(PETSC_SUCCESS); 207 } 208 209 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 210 It is similar to MatCreateMPIAIJWithSplitArrays. 211 212 Input Parameters: 213 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 214 . A - the diag matrix using local col ids 215 - B - the offdiag matrix using global col ids 216 217 Output Parameter: 218 . 
mat - the updated MATMPIAIJKOKKOS matrix 219 */ 220 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 221 { 222 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 223 PetscInt m, n, M, N, Am, An, Bm, Bn; 224 225 PetscFunctionBegin; 226 PetscCall(MatGetSize(mat, &M, &N)); 227 PetscCall(MatGetLocalSize(mat, &m, &n)); 228 PetscCall(MatGetLocalSize(A, &Am, &An)); 229 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 230 231 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 232 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 233 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 234 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 235 mpiaij->A = A; 236 mpiaij->B = B; 237 mpiaij->garray = garray; 238 239 mat->preallocated = PETSC_TRUE; 240 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 241 242 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 243 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 244 /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and 245 also gets mpiaij->B compacted, with its col ids and size reduced 246 */ 247 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 248 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 249 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 250 PetscFunctionReturn(PETSC_SUCCESS); 251 } 252 253 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 254 // split csr matrices. 
The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 255 template <class ExecutionSpace> 256 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 257 { 258 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 259 260 PetscFunctionBegin; 261 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 262 263 if (nnz_per_row < 1) nnz_per_row = 1; 264 265 int max_vector_length = teamPolicy.vector_length_max(); 266 267 if (vector_length < 1) { 268 vector_length = 1; 269 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 270 } 271 272 // Determine rows per thread 273 if (rows_per_thread < 1) { 274 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 275 else { 276 if (nnz_per_row < 20 && nnz > 5000000) { 277 rows_per_thread = 256; 278 } else rows_per_thread = 64; 279 } 280 } 281 282 if (team_size < 1) { 283 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 284 team_size = 256 / vector_length; 285 } else { 286 team_size = 1; 287 } 288 } 289 290 rows_per_team = rows_per_thread * team_size; 291 292 if (rows_per_team < 0) { 293 PetscInt nnz_per_team = 4096; 294 PetscInt conc = ExecutionSpace().concurrency(); 295 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 296 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 297 } 298 PetscFunctionReturn(PETSC_SUCCESS); 299 } 300 301 /* 302 Reduce two sets of global indices into local ones 303 304 Input Parameters: 305 + n1 - size of garray1[], the first set 306 . garray1[n1] - a sorted global index array (without duplicates) 307 . 
m - size of indices[], the second set 308 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 309 310 Output Parameters: 311 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 312 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 313 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 314 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 315 316 Example, say 317 n1 = 5 318 garray1[5] = {1, 4, 7, 8, 10} 319 m = 4 320 indices[4] = {2, 4, 8, 9} 321 322 Combining them together, we have 7 global indices in garray2[] 323 n2 = 7 324 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 325 326 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 327 map[5] = {0, 2, 3, 4, 6} 328 329 On output, indices[] is updated with local indices 330 indices[4] = {1, 2, 4, 5} 331 */ 332 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 333 { 334 PetscHMapI g2l = nullptr; 335 PetscHashIter iter; 336 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 337 PetscInt n2, *garray2; 338 339 PetscFunctionBegin; 340 tot = 0; 341 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 342 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 343 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 344 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 345 } 346 347 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 348 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 349 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 350 } 351 352 // Pull out (unique) globals in the hash table and put them in garray2[] 353 n2 = tot; 354 PetscCall(PetscMalloc1(n2, &garray2)); 355 tot = 0; 356 PetscHashIterBegin(g2l, iter); 357 while (!PetscHashIterAtEnd(g2l, iter)) { 358 PetscHashIterGetKey(g2l, iter, key); 359 PetscHashIterNext(g2l, iter); 360 garray2[tot++] = key; 361 } 362 363 // Sort garray2[] and then map them to local indices starting from 0 364 PetscCall(PetscSortInt(n2, garray2)); 365 PetscCall(PetscHMapIClear(g2l)); 366 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 367 368 // Rewrite indices[] with local indices 369 for (PetscInt i = 0; i < m; i++) { 370 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 371 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 372 indices[i] = val; 373 } 374 // Record the map that maps garray1[i] to garray2[map[i]] 375 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 376 PetscCall(PetscHMapIDestroy(&g2l)); 377 *n2_ = n2; 378 *garray2_ = garray2; 379 PetscFunctionReturn(PETSC_SUCCESS); 380 } 381 382 /* 383 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 384 

  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
  The work is split into MatMPIAIJKokkosReduceBegin() and MatMPIAIJKokkosReduceEnd().

  Think of each row of E as a leaf; the given ownerSF then specifies roots for the leaves. A root may connect to multiple leaves.
  In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows equals nroots of ownerSF.

  Input Parameters:
+  comm        - MPI communicator of E
.  A           - diag block of E (Ed), using local column indices
.  B           - off-diag block of E (Eo), using local column indices
.  cstart      - (global) start column of Ed
.  cend        - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend)
.  garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
.  ownerSF     - the SF that specifies ownership (roots) of rows in E
.  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
-  mm          - to stash intermediate data structures for reuse

  Output Parameters:
+  map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
-  mm      - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
407 408 */ 409 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 410 { 411 PetscFunctionBegin; 412 if (reuse == MAT_INITIAL_MATRIX) { 413 PetscInt Em = A.numRows(), Fm; 414 PetscInt n1 = B.numCols(); 415 416 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 417 418 // Do the analysis on host 419 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 420 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 421 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 422 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 423 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 424 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 425 426 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 427 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 428 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 429 for (PetscInt i = 0; i < Em; i++) { 430 const PetscInt *first, *last, *it; 431 PetscInt count, step; 432 // std::lower_bound(first,last,cstart), but need to use global column indices 433 first = Bj + Bi[i]; 434 last = Bj + Bi[i + 1]; 435 count = last - first; 436 while (count > 0) { 437 it = first; 438 step = count / 2; 439 it += step; 440 if (garray1[*it] < cstart) { // map local to global 441 first = ++it; 442 count -= step + 1; 443 } else count = step; 444 } 445 E_NzLeft[i] = first - (Bj + Bi[i]); 446 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 447 } 448 449 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 450 const PetscMPIInt *iranks, *ranks; 451 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 452 PetscInt niranks, nranks; 453 MPI_Request *reqs; 454 PetscMPIInt tag; 455 PetscSF reduceSF; 456 PetscInt *sdisp, *rdisp; 457 458 PetscCall(PetscCommGetNewTag(comm, &tag)); 459 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 460 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 461 462 // Find out length of each row I will receive. Even for the same row index, when they are from 463 // different senders, they might have different lengths (and sparsity patterns) 464 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 465 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 466 467 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 468 469 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 470 recvRowLen[0] = 0; // since we will make it in CSR format later 471 recvRowLen++; // advance the pointer now 472 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 473 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 474 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 475 476 // Build the real PetscSF for reducing E rows (buffer to buffer) 477 rdisp[0] = 0; 478 for (PetscInt i = 0; i < niranks; i++) { 479 rdisp[i + 1] = rdisp[i]; 480 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 481 } 482 recvRowLen--; // put it back into csr format 483 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 
1] += recvRowLen[i]; 484 485 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 486 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 487 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 488 489 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 490 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 491 PetscSFNode *iremote; 492 493 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 494 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 495 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 496 497 for (PetscInt i = 0; i < nranks; i++) { 498 PetscInt count = 0; 499 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 500 for (PetscInt j = 0; j < count; j++) { 501 iremote[nleaves + j].rank = ranks[i]; 502 iremote[nleaves + j].index = sdisp[i] + j; 503 } 504 nleaves += count; 505 } 506 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 507 508 PetscCall(PetscSFCreate(comm, &reduceSF)); 509 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 510 511 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 512 PetscInt *sendCol, *recvCol; 513 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 514 for (PetscInt k = 0; k < roffset[nranks]; k++) { 515 PetscInt i = rmine[k]; // row to be copied 516 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 517 PetscInt nzLeft = E_NzLeft[i]; 518 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 519 for (PetscInt j = 0; j < alen + blen; j++) { 520 if (j < nzLeft) { 521 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 522 } else if (j < nzLeft + alen) { 523 buf[j] = Aj[Ai[i] 
+ j - nzLeft] + cstart; // diag A, also in global 524 } else { 525 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 526 } 527 } 528 } 529 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 530 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 531 532 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 533 PetscInt *recvRowPerm, *recvColSorted; 534 PetscInt *recvNzPerm, *recvNzPermSorted; 535 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 536 537 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 538 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 539 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 540 541 // i[] array, nz are always easiest to compute 542 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 543 MatRowMapType *Fdi, *Foi; 544 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 545 PetscInt iter; 546 547 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 548 Kokkos::deep_copy(Foi_h, 0); 549 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 550 Foi = Foi_h.data() + 1; 551 iter = 0; 552 while (iter < recvRowCnt) { // iter over received rows 553 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 554 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 555 556 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 557 558 // Copy column indices 
(and their permutation) of these rows into recvColSorted & recvNzPermSorted 559 PetscInt nz = 0; // nz (with dups) in the current row 560 PetscInt *jbuf = recvColSorted + FnzDups; 561 PetscInt *pbuf = recvNzPermSorted + FnzDups; 562 PetscInt *jbuf2 = jbuf; // temp pointers 563 PetscInt *pbuf2 = pbuf; 564 for (PetscInt d = 0; d < dupRows; d++) { 565 PetscInt i = recvRowPerm[iter + d]; 566 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 567 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 568 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 569 jbuf2 += len; 570 pbuf2 += len; 571 nz += len; 572 } 573 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 574 575 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 576 PetscInt cur = 0; 577 while (cur < nz) { 578 PetscInt curColIdx = jbuf[cur]; 579 PetscInt dups = 1; 580 581 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 582 if (curColIdx >= cstart && curColIdx < cend) { 583 Fdi[curRowIdx]++; 584 FdnzDups += dups; 585 } else { 586 Foi[curRowIdx]++; 587 FonzDups += dups; 588 } 589 cur += dups; 590 } 591 592 FnzDups += nz; 593 iter += dupRows; // Move to next unique row 594 } 595 596 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 597 Foi = Foi_h.data(); 598 for (PetscInt i = 0; i < Fm; i++) { 599 Fdi[i + 1] += Fdi[i]; 600 Foi[i + 1] += Foi[i]; 601 } 602 Fdnz = Fdi[Fm]; 603 Fonz = Foi[Fm]; 604 PetscCall(PetscFree2(sendCol, recvCol)); 605 606 // Allocate j, jmap, jperm for Fd and Fo 607 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 608 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 609 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 610 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 611 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 612 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 613 614 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 615 Fdjmap[0] = 0; 616 Fojmap[0] = 0; 617 FnzDups = 0; 618 Fdnz = 0; 619 Fonz = 0; 620 iter = 0; // iter over received rows 621 while (iter < recvRowCnt) { 622 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 623 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 624 PetscInt nz = 0; // nz (with dups) in the current row 625 626 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 627 for (PetscInt d = 0; d < dupRows; d++) { 628 PetscInt i = recvRowPerm[iter + d]; 629 nz += recvRowLen[i + 1] - recvRowLen[i]; 630 } 631 632 PetscInt *jbuf = recvColSorted + FnzDups; 633 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 634 PetscInt cur = 0; 635 while (cur < nz) { 636 PetscInt curColIdx = jbuf[cur]; 637 PetscInt dups = 1; 638 639 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 640 if (curColIdx >= cstart && curColIdx < cend) { 641 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 642 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 643 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 644 FdnzDups += dups; 645 Fdnz++; 646 } else { 647 Foj[Fonz] = curColIdx; // in global 648 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 649 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 650 FonzDups += dups; 651 Fonz++; 652 } 653 cur += dups; 654 FnzDups += dups; 655 } 656 iter += dupRows; // Move to next unique row 657 } 658 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 659 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 660 661 // Combine global column indices in garray1[] and Foj[] 662 PetscInt n2, *garray2; 663 664 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 665 mm->sf = reduceSF; 666 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 667 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 668 mm->garray = garray2; // give ownership, so no free 669 mm->n = n2; 670 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 671 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 672 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 673 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 674 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 675 676 // Output Fd and Fo in KokkosCsrMatrix format 677 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 678 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 679 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 680 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 681 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 682 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 683 684 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 685 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 686 687 // Compute kernel launch parameters in merging E 688 PetscInt teamSize, vectorLength, rowsPerTeam; 689 690 teamSize = vectorLength = rowsPerTeam = -1; 691 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 692 mm->E_TeamSize = teamSize; 693 mm->E_VectorLength = 
vectorLength; 694 mm->E_RowsPerTeam = rowsPerTeam; 695 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 696 697 // Handy aliases 698 auto &Aa = A.values; 699 auto &Ba = B.values; 700 const auto &Ai = A.graph.row_map; 701 const auto &Bi = B.graph.row_map; 702 const auto &E_NzLeft = mm->E_NzLeft; 703 auto &leafBuf = mm->leafBuf; 704 auto &rootBuf = mm->rootBuf; 705 PetscSF reduceSF = mm->sf; 706 PetscInt Em = A.numRows(); 707 PetscInt teamSize = mm->E_TeamSize; 708 PetscInt vectorLength = mm->E_VectorLength; 709 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 710 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 711 712 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 713 PetscCallCXX(Kokkos::parallel_for( 714 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 715 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 716 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 717 if (i < Em) { 718 PetscInt disp = Ai(i) + Bi(i); 719 PetscInt alen = Ai(i + 1) - Ai(i); 720 PetscInt blen = Bi(i + 1) - Bi(i); 721 PetscInt nzleft = E_NzLeft(i); 722 723 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 724 MatScalar &val = leafBuf(disp + j); 725 if (j < nzleft) { // B left 726 val = Ba(Bi(i) + j); 727 } else if (j < nzleft + alen) { // diag A 728 val = Aa(Ai(i) + j - nzleft); 729 } else { // B right 730 val = Ba(Bi(i) + j - alen); 731 } 732 }); 733 } 734 }); 735 })); 736 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 737 PetscFunctionReturn(PETSC_SUCCESS); 738 } 739 740 // To finish MatMPIAIJKokkosReduce. 
741 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 742 { 743 PetscFunctionBegin; 744 auto &leafBuf = mm->leafBuf; 745 auto &rootBuf = mm->rootBuf; 746 auto &Fda = mm->Fd.values; 747 const auto &Fdjmap = mm->Fdjmap; 748 const auto &Fdjperm = mm->Fdjperm; 749 auto Fdnz = mm->Fd.nnz(); 750 auto &Foa = mm->Fo.values; 751 const auto &Fojmap = mm->Fojmap; 752 const auto &Fojperm = mm->Fojperm; 753 auto Fonz = mm->Fo.nnz(); 754 PetscSF reduceSF = mm->sf; 755 756 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 757 758 // Reduce data in rootBuf to Fd and Fo 759 PetscCallCXX(Kokkos::parallel_for( 760 Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) { 761 PetscScalar sum = 0.0; 762 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 763 Fda(i) = sum; 764 })); 765 766 PetscCallCXX(Kokkos::parallel_for( 767 Fonz, KOKKOS_LAMBDA(const MatRowMapType i) { 768 PetscScalar sum = 0.0; 769 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 770 Foa(i) = sum; 771 })); 772 PetscFunctionReturn(PETSC_SUCCESS); 773 } 774 775 /* 776 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 777 778 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 779 device and involves various index mapping. 780 781 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 782 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 783 to j-th row of F. 
ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 784 F has the same column layout as E. 785 786 Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 787 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 788 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 789 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 790 column indices in Fo and update Fo with local indices. 791 792 Input Parameters: 793 + E - the MPIAIJKOKKOS matrix 794 . ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX) 795 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 796 - mm - to stash matproduct intermediate data structures 797 798 Output Parameters: 799 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 800 - mm - contains various info, such as garray2[], Fd, Fo, etc. 801 802 Notes: 803 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 804 The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 
805 */ 806 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 807 { 808 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 809 Mat A = empi->A, B = empi->B; // diag and off-diag 810 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 811 PetscInt Em = E->rmap->n; // #local rows 812 MPI_Comm comm; 813 814 PetscFunctionBegin; 815 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 816 if (reuse == MAT_INITIAL_MATRIX) { 817 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 818 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 819 const PetscInt *garray1 = empi->garray; // its size is n1 820 PetscInt cstart, cend; 821 PetscSF bcastSF; 822 823 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 824 825 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 826 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 827 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 828 for (PetscInt i = 0; i < Em; i++) { 829 const PetscInt *first, *last, *it; 830 PetscInt count, step; 831 // std::lower_bound(first,last,cstart), but need to use global column indices 832 first = Bj + Bi[i]; 833 last = Bj + Bi[i + 1]; 834 count = last - first; 835 while (count > 0) { 836 it = first; 837 step = count / 2; 838 it += step; 839 if (empi->garray[*it] < cstart) { // map local to global 840 first = ++it; 841 count -= step + 1; 842 } else count = step; 843 } 844 E_NzLeft[i] = first - (Bj + Bi[i]); 845 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 846 } 847 848 // Compute row pointer Fi of F 849 PetscInt *Fi, Fm, Fnz; 850 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 851 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 852 Fi[0] = 0; 853 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 854 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 855 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 856 Fnz = Fi[Fm]; 857 858 // Build the real PetscSF for bcasting E rows (buffer to buffer) 859 const PetscMPIInt *iranks, *ranks; 860 const PetscInt *ioffset, *irootloc, *roffset; 861 PetscInt niranks, nranks, *sdisp, *rdisp; 862 MPI_Request *reqs; 863 PetscMPIInt tag; 864 865 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 866 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 867 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 868 869 sdisp[0] = 0; // send displacement 870 for (PetscInt i = 0; i < niranks; i++) { 871 sdisp[i + 1] = sdisp[i]; 872 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 873 PetscInt r = irootloc[j]; // row to be sent 874 sdisp[i + 1] += E_RowLen[r]; 875 } 876 } 877 878 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 879 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 880 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 881 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 882 883 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 884 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 885 PetscSFNode *iremote; // give ownership to bcastSF 886 PetscCall(PetscMalloc1(nleaves, &iremote)); 887 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 888 PetscInt k = 0; 889 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 890 iremote[j].rank = ranks[i]; 891 iremote[j].index = rdisp[i] + k; // their root location 892 k++; 893 } 894 } 895 PetscCall(PetscSFCreate(comm, &bcastSF)); 896 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 897 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 898 899 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 900 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 901 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 902 rowoffset[0] = 0; 903 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 904 905 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 906 PetscInt *jbuf, *Fj; 907 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 908 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 909 PetscInt i = irootloc[k]; // row to be copied 910 PetscInt *buf = &jbuf[rowoffset[k]]; 911 PetscInt nzLeft = E_NzLeft[i]; 912 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 913 for (PetscInt j = 0; j < alen + blen; j++) { 914 if (j < nzLeft) { 915 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 916 } else if (j < nzLeft + alen) { 917 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 918 } else { 919 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 920 } 921 } 922 } 923 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 924 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 925 926 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 927 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 928 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 929 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 930 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 931 932 Fdi[0] = Foi[0] = 0; 933 for (PetscInt i = 0; i < Fm; i++) { 934 PetscInt *first, *last, *lb1, *lb2; 935 // cut the row into: Left, [cstart, cend), Right 936 first = Fj + Fi[i]; 937 last = Fj + Fi[i + 1]; 938 lb1 = std::lower_bound(first, last, cstart); 939 F_NzLeft[i] = lb1 - first; 940 lb2 = std::lower_bound(first, last, cend); 941 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 942 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 943 } 944 for (PetscInt i = 0; i < Fm; i++) { 945 Fdi[i + 1] += Fdi[i]; 946 Foi[i + 1] += Foi[i]; 947 } 948 949 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 950 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 951 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 952 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 953 954 for (PetscInt i = 0; i < Fm; i++) { 955 PetscInt nzLeft = F_NzLeft[i]; 956 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 957 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 958 gid = Fj[Fi[i] + j]; 959 if (j < nzLeft) { // left, in global 960 Foj[Foi[i] + j] = gid; 961 } else if (j < nzLeft + len) { // diag, in local 962 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 963 } else { // right, in global 964 Foj[Foi[i] + j - len] = gid; 965 } 966 } 967 } 968 PetscCall(PetscFree2(jbuf, Fj)); 969 PetscCall(PetscFree(Fi)); 970 971 // Reduce global indices in Foj[] and garray1[] into local ones 972 PetscInt n2, *garray2; 973 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 974 975 // Record the plans built above, for reuse 976 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 977 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 978 Kokkos::deep_copy(irootloc_h, tmp); 979 mm->sf = bcastSF; 980 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 981 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 982 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 983 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 984 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 985 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 986 mm->garray = garray2; 987 mm->n = n2; 988 989 // Output Fd and Fo in KokkosCsrMatrix format 990 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 991 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 992 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 993 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 994 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 995 996 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 997 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 998 999 // Compute kernel launch parameters in merging E or splitting F 1000 PetscInt teamSize, vectorLength, rowsPerTeam; 1001 1002 teamSize = vectorLength = rowsPerTeam = -1; 1003 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 1004 mm->E_TeamSize = teamSize; 1005 mm->E_VectorLength = vectorLength; 1006 mm->E_RowsPerTeam = rowsPerTeam; 1007 1008 teamSize = vectorLength = rowsPerTeam = -1; 1009 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, 
teamSize, vectorLength, rowsPerTeam)); 1010 mm->F_TeamSize = teamSize; 1011 mm->F_VectorLength = vectorLength; 1012 mm->F_RowsPerTeam = rowsPerTeam; 1013 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 1014 1015 // Sync E's value to device 1016 akok->a_dual.sync_device(); 1017 bkok->a_dual.sync_device(); 1018 1019 // Handy aliases 1020 const auto &Aa = akok->a_dual.view_device(); 1021 const auto &Ba = bkok->a_dual.view_device(); 1022 const auto &Ai = akok->i_dual.view_device(); 1023 const auto &Bi = bkok->i_dual.view_device(); 1024 1025 // Fetch the plans 1026 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1027 PetscSF &bcastSF = mm->sf; 1028 MatScalarKokkosView &rootBuf = mm->rootBuf; 1029 MatScalarKokkosView &leafBuf = mm->leafBuf; 1030 PetscIntKokkosView &irootloc = mm->irootloc; 1031 PetscIntKokkosView &rowoffset = mm->rowoffset; 1032 1033 PetscInt teamSize = mm->E_TeamSize; 1034 PetscInt vectorLength = mm->E_VectorLength; 1035 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1036 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1037 1038 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1039 PetscCallCXX(Kokkos::parallel_for( 1040 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1041 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1042 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1043 if (r < irootloc.extent(0)) { 1044 PetscInt i = irootloc(r); // row i of E 1045 PetscInt disp = rowoffset(r); 1046 PetscInt alen = Ai(i + 1) - Ai(i); 1047 PetscInt blen = Bi(i + 1) - Bi(i); 1048 PetscInt nzleft = E_NzLeft(i); 1049 1050 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1051 if (j < nzleft) { // B left 1052 rootBuf(disp + j) = Ba(Bi(i) + j); 1053 } else if (j < nzleft + alen) { // diag A 1054 rootBuf(disp + j) = Aa(Ai(i) + j - 
nzleft); 1055 } else { // B right 1056 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1057 } 1058 }); 1059 } 1060 }); 1061 })); 1062 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1063 PetscFunctionReturn(PETSC_SUCCESS); 1064 } 1065 1066 // To finish MatMPIAIJKokkosBcast. 1067 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1068 { 1069 PetscFunctionBegin; 1070 const auto &Fd = mm->Fd; 1071 const auto &Fo = mm->Fo; 1072 const auto &Fdi = Fd.graph.row_map; 1073 const auto &Foi = Fo.graph.row_map; 1074 auto &Fda = Fd.values; 1075 auto &Foa = Fo.values; 1076 auto Fm = Fd.numRows(); 1077 1078 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1079 PetscSF &bcastSF = mm->sf; 1080 MatScalarKokkosView &rootBuf = mm->rootBuf; 1081 MatScalarKokkosView &leafBuf = mm->leafBuf; 1082 PetscInt teamSize = mm->F_TeamSize; 1083 PetscInt vectorLength = mm->F_VectorLength; 1084 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1085 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1086 1087 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1088 1089 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1090 PetscCallCXX(Kokkos::parallel_for( 1091 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1092 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1093 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1094 if (i < Fm) { 1095 PetscInt nzLeft = F_NzLeft(i); 1096 PetscInt alen = Fdi(i + 1) - Fdi(i); 1097 PetscInt blen = Foi(i + 1) - Foi(i); 1098 PetscInt Fii = Fdi(i) + Foi(i); 1099 1100 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1101 PetscScalar val = leafBuf(Fii + j); 1102 if (j < nzLeft) { // left 1103 Foa(Foi(i) + j) = val; 
1104 } else if (j < nzLeft + alen) { // diag 1105 Fda(Fdi(i) + j - nzLeft) = val; 1106 } else { // right 1107 Foa(Foi(i) + j - alen) = val; 1108 } 1109 }); 1110 } 1111 }); 1112 })); 1113 PetscFunctionReturn(PETSC_SUCCESS); 1114 } 1115 1116 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1117 { 1118 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1119 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1120 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1121 PetscInt cstart, cend; 1122 MPI_Comm comm; 1123 1124 PetscFunctionBegin; 1125 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1126 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1127 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1128 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1129 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1130 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1131 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1132 1133 // TODO: add command line options to select spgemm algorithms 1134 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1135 1136 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1137 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1138 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1139 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1140 #endif 1141 #endif 1142 1143 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1144 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1145 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1146 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1147 1148 // Aot * (B's diag + B's off-diag) 1149 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1150 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1151 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1152 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1153 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1154 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1155 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1156 PetscCallCXX(sort_crs_matrix(mm->C3)); 1157 PetscCallCXX(sort_crs_matrix(mm->C4)); 1158 #endif 1159 1160 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1161 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1162 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1163 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1164 1165 // Adt * (B's diag + B's off-diag) 1166 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1167 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1168 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1169 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1170 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1171 PetscCallCXX(sort_crs_matrix(mm->C1)); 1172 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1173 #endif 1174 1175 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1176 1177 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1178 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1179 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1180 PetscCallCXX(Kokkos::parallel_for( 1181 oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1182 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1183 1184 // C = (C1+Fd, C2+Fo) 1185 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1186 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1187 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1188 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1189 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1190 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1191 PetscFunctionReturn(PETSC_SUCCESS); 1192 } 1193 1194 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1195 { 1196 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1197 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1198 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1199 MPI_Comm comm; 1200 1201 PetscFunctionBegin; 1202 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1203 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1204 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1205 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1206 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1207 1208 // Aot * (B's diag + B's off-diag) 1209 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1210 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1211 1212 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1213 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1214 1215 // Adt * (B's diag + B's off-diag) 1216 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1217 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1218 1219 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1220 1221 // C = (C1+Fd, C2+Fo) 1222 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1223 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1224 PetscFunctionReturn(PETSC_SUCCESS); 1225 } 1226 1227 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1228 1229 Input Parameters: 1230 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1231 . A - an MPIAIJKOKKOS matrix 1232 . B - an MPIAIJKOKKOS matrix 1233 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1234 */ 1235 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1236 { 1237 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1238 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1239 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1240 1241 PetscFunctionBegin; 1242 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1243 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1244 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1245 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1246 1247 // TODO: add command line options to select spgemm algorithms 1248 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1249 1250 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1251 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1252 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1253 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1254 #endif 1255 #endif 1256 1257 mm->kh1.create_spgemm_handle(spgemm_alg); 1258 mm->kh2.create_spgemm_handle(spgemm_alg); 1259 mm->kh3.create_spgemm_handle(spgemm_alg); 1260 mm->kh4.create_spgemm_handle(spgemm_alg); 1261 1262 // Bcast B's rows to form F, and overlap the communication 1263 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1264 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1265 1266 // A's diag * (B's diag + B's off-diag) 1267 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1268 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1269 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1270 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1271 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1272 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1273 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1274 PetscCallCXX(sort_crs_matrix(mm->C1)); 1275 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1276 #endif 1277 1278 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1279 1280 // A's off-diag * (F's diag + F's off-diag) 1281 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1282 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1283 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1284 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1285 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1286 PetscCallCXX(sort_crs_matrix(mm->C3)); 1287 PetscCallCXX(sort_crs_matrix(mm->C4)); 1288 #endif 1289 1290 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1291 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1292 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1293 PetscCallCXX(Kokkos::parallel_for( 1294 oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1295 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1296 1297 // C = (Cd, Co) = (C1+C3, C2+C4) 1298 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1299 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1300 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1301 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1302 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 
1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1303 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1304 PetscFunctionReturn(PETSC_SUCCESS); 1305 } 1306 1307 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1308 { 1309 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1310 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1311 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1312 1313 PetscFunctionBegin; 1314 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1315 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1316 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1317 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1318 1319 // Bcast B's rows to form F, and overlap the communication 1320 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1321 1322 // A's diag * (B's diag + B's off-diag) 1323 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1324 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1325 1326 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1327 1328 // A's off-diag * (F's diag + F's off-diag) 1329 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1330 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1331 1332 // C = (Cd, Co) = (C1+C3, C2+C4) 1333 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1334 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1335 PetscFunctionReturn(PETSC_SUCCESS); 1336 } 1337 1338 PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1339 { 1340 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1341 Mat_Product *product; 1342 MatProductData_MPIAIJKokkos *pdata; 1343 MatProductType ptype; 1344 Mat A, B; 1345 1346 PetscFunctionBegin; 
1347 MatCheckProduct(C, 1); // make sure C is a product 1348 product = C->product; 1349 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1350 ptype = product->type; 1351 A = product->A; 1352 B = product->B; 1353 1354 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1355 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1356 // we still do numeric. 1357 if (pdata->reusesym) { // numeric reuses results from symbolic 1358 pdata->reusesym = PETSC_FALSE; 1359 PetscFunctionReturn(PETSC_SUCCESS); 1360 } 1361 1362 if (ptype == MATPRODUCT_AB) { 1363 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1364 } else if (ptype == MATPRODUCT_AtB) { 1365 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1366 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1367 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1368 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1369 } 1370 1371 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1372 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1373 PetscFunctionReturn(PETSC_SUCCESS); 1374 } 1375 1376 PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1377 { 1378 Mat A, B; 1379 Mat_Product *product; 1380 MatProductType ptype; 1381 MatProductData_MPIAIJKokkos *pdata; 1382 MatMatStruct *mm = NULL; 1383 PetscInt m, n, M, N; 1384 Mat Cd, Co; 1385 MPI_Comm comm; 1386 1387 PetscFunctionBegin; 1388 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1389 MatCheckProduct(C, 1); 1390 product = C->product; 1391 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1392 ptype = product->type; 1393 A = product->A; 1394 B = product->B; 1395 1396 switch (ptype) { 1397 case MATPRODUCT_AB: 1398 m = A->rmap->n; 1399 
n = B->cmap->n; 1400 M = A->rmap->N; 1401 N = B->cmap->N; 1402 break; 1403 case MATPRODUCT_AtB: 1404 m = A->cmap->n; 1405 n = B->cmap->n; 1406 M = A->cmap->N; 1407 N = B->cmap->N; 1408 break; 1409 case MATPRODUCT_PtAP: 1410 m = B->cmap->n; 1411 n = B->cmap->n; 1412 M = B->cmap->N; 1413 N = B->cmap->N; 1414 break; /* BtAB */ 1415 default: 1416 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1417 } 1418 1419 PetscCall(MatSetSizes(C, m, n, M, N)); 1420 PetscCall(PetscLayoutSetUp(C->rmap)); 1421 PetscCall(PetscLayoutSetUp(C->cmap)); 1422 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1423 1424 pdata = new MatProductData_MPIAIJKokkos(); 1425 pdata->reusesym = product->api_user; 1426 1427 if (ptype == MATPRODUCT_AB) { 1428 auto mmAB = new MatMatStruct_AB(); 1429 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1430 mm = pdata->mmAB = mmAB; 1431 } else if (ptype == MATPRODUCT_AtB) { 1432 auto mmAtB = new MatMatStruct_AtB(); 1433 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1434 mm = pdata->mmAtB = mmAtB; 1435 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1436 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1437 1438 auto mmAB = new MatMatStruct_AB(); 1439 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1440 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1441 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1442 pdata->mmAB = mmAB; 1443 1444 m = A->rmap->n; // Z's layout 1445 n = B->cmap->n; 1446 M = A->rmap->N; 1447 N = B->cmap->N; 1448 PetscCall(MatCreate(comm, &Z)); 1449 PetscCall(MatSetSizes(Z, m, n, M, N)); 1450 PetscCall(PetscLayoutSetUp(Z->rmap)); 1451 PetscCall(PetscLayoutSetUp(Z->cmap)); 1452 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1453 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 
1454 1455 auto mmAtB = new MatMatStruct_AtB(); 1456 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1457 1458 pdata->Z = Z; // give ownership to pdata 1459 mm = pdata->mmAtB = mmAtB; 1460 } 1461 1462 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1463 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1464 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1465 1466 C->product->data = pdata; 1467 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1468 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1469 PetscFunctionReturn(PETSC_SUCCESS); 1470 } 1471 1472 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1473 { 1474 Mat_Product *product = mat->product; 1475 PetscBool match = PETSC_FALSE; 1476 PetscBool usecpu = PETSC_FALSE; 1477 1478 PetscFunctionBegin; 1479 MatCheckProduct(mat, 1); 1480 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1481 if (match) { /* we can always fallback to the CPU if requested */ 1482 switch (product->type) { 1483 case MATPRODUCT_AB: 1484 if (product->api_user) { 1485 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1486 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1487 PetscOptionsEnd(); 1488 } else { 1489 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1490 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1491 PetscOptionsEnd(); 1492 } 1493 break; 1494 case MATPRODUCT_AtB: 1495 if (product->api_user) { 1496 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), 
((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1497 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1498 PetscOptionsEnd(); 1499 } else { 1500 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1501 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1502 PetscOptionsEnd(); 1503 } 1504 break; 1505 case MATPRODUCT_PtAP: 1506 if (product->api_user) { 1507 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1508 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1509 PetscOptionsEnd(); 1510 } else { 1511 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1512 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1513 PetscOptionsEnd(); 1514 } 1515 break; 1516 default: 1517 break; 1518 } 1519 match = (PetscBool)!usecpu; 1520 } 1521 if (match) { 1522 switch (product->type) { 1523 case MATPRODUCT_AB: 1524 case MATPRODUCT_AtB: 1525 case MATPRODUCT_PtAP: 1526 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1527 break; 1528 default: 1529 break; 1530 } 1531 } 1532 /* fallback to MPIAIJ ops */ 1533 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1534 PetscFunctionReturn(PETSC_SUCCESS); 1535 } 1536 1537 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1538 { 1539 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 1540 Mat_MPIAIJKokkos *mpikok; 1541 1542 PetscFunctionBegin; 1543 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1544 mat->preallocated = PETSC_TRUE; 
1545 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1546 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1547 PetscCall(MatZeroEntries(mat)); 1548 mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr); 1549 delete mpikok; 1550 mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij); 1551 PetscFunctionReturn(PETSC_SUCCESS); 1552 } 1553 1554 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1555 { 1556 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1557 Mat_MPIAIJKokkos *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr); 1558 Mat A = mpiaij->A, B = mpiaij->B; 1559 PetscCount Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2; 1560 MatScalarKokkosView Aa, Ba; 1561 MatScalarKokkosView v1; 1562 MatScalarKokkosView &vsend = mpikok->sendbuf_d; 1563 const MatScalarKokkosView &v2 = mpikok->recvbuf_d; 1564 const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d; 1565 const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d; 1566 const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d; 1567 const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d; 1568 PetscMemType memtype; 1569 1570 PetscFunctionBegin; 1571 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1572 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1573 v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n)); 1574 } else { 1575 v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */ 1576 } 1577 1578 if (imode == INSERT_VALUES) { 1579 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1580 
PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1581 } else { 1582 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1583 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1584 } 1585 1586 /* Pack entries to be sent to remote */ 1587 Kokkos::parallel_for( 1588 vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 1589 1590 /* Send remote entries to their owner and overlap the communication with local computation */ 1591 PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1592 /* Add local entries to A and B in one kernel */ 1593 Kokkos::parallel_for( 1594 Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) { 1595 PetscScalar sum = 0.0; 1596 if (i < Annz) { 1597 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1598 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1599 } else { 1600 i -= Annz; 1601 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1602 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1603 } 1604 }); 1605 PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1606 1607 /* Add received remote entries to A and B in one kernel */ 1608 Kokkos::parallel_for( 1609 Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) { 1610 if (i < Annz2) { 1611 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1612 } else { 1613 i -= Annz2; 1614 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1615 } 1616 }); 1617 1618 if (imode == INSERT_VALUES) { 1619 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1620 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1621 } else { 1622 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1623 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1624 } 1625 PetscFunctionReturn(PETSC_SUCCESS); 1626 } 1627 1628 PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1629 { 1630 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 1631 1632 PetscFunctionBegin; 1633 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1634 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1635 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1636 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1637 delete (Mat_MPIAIJKokkos *)mpiaij->spptr; 1638 PetscCall(MatDestroy_MPIAIJ(A)); 1639 PetscFunctionReturn(PETSC_SUCCESS); 1640 } 1641 1642 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1643 { 1644 Mat B; 1645 Mat_MPIAIJ *a; 1646 1647 PetscFunctionBegin; 1648 if (reuse == MAT_INITIAL_MATRIX) { 1649 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1650 } else if (reuse == MAT_REUSE_MATRIX) { 1651 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1652 } 1653 B = *newmat; 1654 1655 B->boundtocpu = PETSC_FALSE; 1656 PetscCall(PetscFree(B->defaultvectype)); 1657 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1658 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1659 1660 a = static_cast<Mat_MPIAIJ *>(A->data); 1661 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1662 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1663 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1664 1665 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1666 B->ops->mult = MatMult_MPIAIJKokkos; 1667 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1668 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1669 
B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1670 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1671 1672 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1673 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1674 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1675 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1676 PetscFunctionReturn(PETSC_SUCCESS); 1677 } 1678 /*MC 1679 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1680 1681 A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types 1682 1683 Options Database Key: 1684 . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1685 1686 Level: beginner 1687 1688 .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1689 M*/ 1690 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1691 { 1692 PetscFunctionBegin; 1693 PetscCall(PetscKokkosInitializeCheck()); 1694 PetscCall(MatCreate_MPIAIJ(A)); 1695 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1696 PetscFunctionReturn(PETSC_SUCCESS); 1697 } 1698 1699 /*@C 1700 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1701 (the default parallel PETSc format). This matrix will ultimately pushed down 1702 to Kokkos for calculations. 1703 1704 Collective 1705 1706 Input Parameters: 1707 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1708 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1709 This value should be the same as the local size used in creating the 1710 y vector for the matrix-vector product y = Ax. 
1711 . n - This value should be the same as the local size used in creating the 1712 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1713 calculated if N is given) For square matrices n is almost always `m`. 1714 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1715 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1716 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1717 (same value is used for all local rows) 1718 . d_nnz - array containing the number of nonzeros in the various rows of the 1719 DIAGONAL portion of the local submatrix (possibly different for each row) 1720 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1721 The size of this array is equal to the number of local rows, i.e `m`. 1722 For matrices you plan to factor you must leave room for the diagonal entry and 1723 put in the entry even if it is zero. 1724 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1725 submatrix (same value is used for all local rows). 1726 - o_nnz - array containing the number of nonzeros in the various rows of the 1727 OFF-DIAGONAL portion of the local submatrix (possibly different for 1728 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1729 structure. The size of this array is equal to the number 1730 of local rows, i.e `m`. 1731 1732 Output Parameter: 1733 . A - the matrix 1734 1735 Level: intermediate 1736 1737 Notes: 1738 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1739 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1740 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1741 1742 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1743 storage. That is, the stored row and column indices can begin at 1744 either one (as in Fortran) or zero. 
1745 1746 .seealso: [](chapter_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1747 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1748 @*/ 1749 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1750 { 1751 PetscMPIInt size; 1752 1753 PetscFunctionBegin; 1754 PetscCall(MatCreate(comm, A)); 1755 PetscCall(MatSetSizes(*A, m, n, M, N)); 1756 PetscCallMPI(MPI_Comm_size(comm, &size)); 1757 if (size > 1) { 1758 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1759 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1760 } else { 1761 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1762 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1763 } 1764 PetscFunctionReturn(PETSC_SUCCESS); 1765 } 1766