1 #include <petscvec_kokkos.hpp> 2 #include <petscpkg_version.h> 3 #include <petsc/private/sfimpl.h> 4 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 5 #include <../src/mat/impls/aij/mpi/mpiaij.h> 6 #include <KokkosSparse_spadd.hpp> 7 #include <KokkosSparse_spgemm.hpp> 8 9 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 10 { 11 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 12 13 PetscFunctionBegin; 14 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 15 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 16 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 17 */ 18 if (mode == MAT_FINAL_ASSEMBLY) { 19 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 20 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 21 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 22 } 23 PetscFunctionReturn(PETSC_SUCCESS); 24 } 25 26 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 27 { 28 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 29 30 PetscFunctionBegin; 31 // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops 32 if (mat->hash_active) { 33 mat->ops[0] = mpiaij->cops; 34 mat->hash_active = PETSC_FALSE; 35 } 36 37 PetscCall(PetscLayoutSetUp(mat->rmap)); 38 PetscCall(PetscLayoutSetUp(mat->cmap)); 39 #if defined(PETSC_USE_DEBUG) 40 if (d_nnz) { 41 PetscInt i; 42 for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); 43 } 44 if (o_nnz) { 45 PetscInt i; 46 for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]); 47 } 
48 #endif 49 #if defined(PETSC_USE_CTABLE) 50 PetscCall(PetscHMapIDestroy(&mpiaij->colmap)); 51 #else 52 PetscCall(PetscFree(mpiaij->colmap)); 53 #endif 54 PetscCall(PetscFree(mpiaij->garray)); 55 PetscCall(VecDestroy(&mpiaij->lvec)); 56 PetscCall(VecScatterDestroy(&mpiaij->Mvctx)); 57 /* Because the B will have been resized we simply destroy it and create a new one each time */ 58 PetscCall(MatDestroy(&mpiaij->B)); 59 60 if (!mpiaij->A) { 61 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A)); 62 PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n)); 63 } 64 if (!mpiaij->B) { 65 PetscMPIInt size; 66 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size)); 67 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B)); 68 PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0)); 69 } 70 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 71 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 72 PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz)); 73 PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz)); 74 mat->preallocated = PETSC_TRUE; 75 PetscFunctionReturn(PETSC_SUCCESS); 76 } 77 78 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 79 { 80 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 81 PetscInt nt; 82 83 PetscFunctionBegin; 84 PetscCall(VecGetLocalSize(xx, &nt)); 85 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 86 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 87 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 88 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 89 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 90 PetscFunctionReturn(PETSC_SUCCESS); 91 } 92 93 static PetscErrorCode 
MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 94 { 95 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 96 PetscInt nt; 97 98 PetscFunctionBegin; 99 PetscCall(VecGetLocalSize(xx, &nt)); 100 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 101 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 102 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 103 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 104 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 105 PetscFunctionReturn(PETSC_SUCCESS); 106 } 107 108 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 109 { 110 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 111 PetscInt nt; 112 113 PetscFunctionBegin; 114 PetscCall(VecGetLocalSize(xx, &nt)); 115 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 116 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 117 PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 118 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 119 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 120 PetscFunctionReturn(PETSC_SUCCESS); 121 } 122 123 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 124 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 125 C still uses local column ids. Their corresponding global column ids are returned in glob. 
126 */ 127 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 128 { 129 Mat Ad, Ao; 130 const PetscInt *cmap; 131 132 PetscFunctionBegin; 133 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 134 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 135 if (glob) { 136 PetscInt cst, i, dn, on, *gidx; 137 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 138 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 139 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 140 PetscCall(PetscMalloc1(dn + on, &gidx)); 141 for (i = 0; i < dn; i++) gidx[i] = cst + i; 142 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 143 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 144 } 145 PetscFunctionReturn(PETSC_SUCCESS); 146 } 147 148 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 149 struct MatMatStruct { 150 PetscInt n, *garray; // C's garray and its size. 151 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 152 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 153 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 154 PetscIntKokkosView E_NzLeft; 155 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 156 MatScalarKokkosView rootBuf, leafBuf; 157 KokkosCsrMatrix Fd, Fo; // F in split form 158 159 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 160 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 161 KernelHandle kh3; // compute C3 162 KernelHandle kh4; // compute C4 163 164 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 165 PetscInt E_VectorLength; 166 PetscInt E_RowsPerTeam; 167 PetscInt F_TeamSize; 168 PetscInt F_VectorLength; 169 PetscInt F_RowsPerTeam; 170 171 ~MatMatStruct() 172 { 173 PetscFunctionBegin; 174 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 175 PetscFunctionReturnVoid(); 
176 } 177 }; 178 179 struct MatMatStruct_AB : public MatMatStruct { 180 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 181 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 182 PetscIntKokkosView rowoffset; 183 }; 184 185 struct MatMatStruct_AtB : public MatMatStruct { 186 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 187 MatColIdxKokkosView Fdjperm; 188 MatColIdxKokkosView Fojmap; 189 MatColIdxKokkosView Fojperm; 190 }; 191 192 struct MatProductData_MPIAIJKokkos { 193 MatMatStruct_AB *mmAB = nullptr; 194 MatMatStruct_AtB *mmAtB = nullptr; 195 PetscBool reusesym = PETSC_FALSE; 196 Mat Z = nullptr; // store Z=AB in computing BtAB 197 198 ~MatProductData_MPIAIJKokkos() 199 { 200 delete mmAB; 201 delete mmAtB; 202 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 203 } 204 }; 205 206 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 207 { 208 PetscFunctionBegin; 209 PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 210 PetscFunctionReturn(PETSC_SUCCESS); 211 } 212 213 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 214 It is similar to MatCreateMPIAIJWithSplitArrays. 215 216 Input Parameters: 217 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 218 . A - the diag matrix using local col ids 219 - B - the offdiag matrix using global col ids 220 221 Output Parameter: 222 . 
mat - the updated MATMPIAIJKOKKOS matrix 223 */ 224 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 225 { 226 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 227 PetscInt m, n, M, N, Am, An, Bm, Bn; 228 229 PetscFunctionBegin; 230 PetscCall(MatGetSize(mat, &M, &N)); 231 PetscCall(MatGetLocalSize(mat, &m, &n)); 232 PetscCall(MatGetLocalSize(A, &Am, &An)); 233 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 234 235 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 236 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 237 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 238 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 239 mpiaij->A = A; 240 mpiaij->B = B; 241 mpiaij->garray = garray; 242 243 mat->preallocated = PETSC_TRUE; 244 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 245 246 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 247 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 248 /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and 249 also gets mpiaij->B compacted, with its col ids and size reduced 250 */ 251 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 252 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 253 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 254 PetscFunctionReturn(PETSC_SUCCESS); 255 } 256 257 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 258 // split csr matrices. 
The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 259 template <class ExecutionSpace> 260 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 261 { 262 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 263 264 PetscFunctionBegin; 265 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 266 267 if (nnz_per_row < 1) nnz_per_row = 1; 268 269 int max_vector_length = teamPolicy.vector_length_max(); 270 271 if (vector_length < 1) { 272 vector_length = 1; 273 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 274 } 275 276 // Determine rows per thread 277 if (rows_per_thread < 1) { 278 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 279 else { 280 if (nnz_per_row < 20 && nnz > 5000000) { 281 rows_per_thread = 256; 282 } else rows_per_thread = 64; 283 } 284 } 285 286 if (team_size < 1) { 287 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 288 team_size = 256 / vector_length; 289 } else { 290 team_size = 1; 291 } 292 } 293 294 rows_per_team = rows_per_thread * team_size; 295 296 if (rows_per_team < 0) { 297 PetscInt nnz_per_team = 4096; 298 PetscInt conc = ExecutionSpace().concurrency(); 299 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 300 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 301 } 302 PetscFunctionReturn(PETSC_SUCCESS); 303 } 304 305 /* 306 Reduce two sets of global indices into local ones 307 308 Input Parameters: 309 + n1 - size of garray1[], the first set 310 . garray1[n1] - a sorted global index array (without duplicates) 311 . 
m - size of indices[], the second set 312 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 313 314 Output Parameters: 315 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 316 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 317 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 318 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 319 320 Example, say 321 n1 = 5 322 garray1[5] = {1, 4, 7, 8, 10} 323 m = 4 324 indices[4] = {2, 4, 8, 9} 325 326 Combining them together, we have 7 global indices in garray2[] 327 n2 = 7 328 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 329 330 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 331 map[5] = {0, 2, 3, 4, 6} 332 333 On output, indices[] is updated with local indices 334 indices[4] = {1, 2, 4, 5} 335 */ 336 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 337 { 338 PetscHMapI g2l = nullptr; 339 PetscHashIter iter; 340 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 341 PetscInt n2, *garray2; 342 343 PetscFunctionBegin; 344 tot = 0; 345 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 346 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 347 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 348 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 349 } 350 351 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 352 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 353 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 354 } 355 356 // Pull out (unique) globals in the hash table and put them in garray2[] 357 n2 = tot; 358 PetscCall(PetscMalloc1(n2, &garray2)); 359 tot = 0; 360 PetscHashIterBegin(g2l, iter); 361 while (!PetscHashIterAtEnd(g2l, iter)) { 362 PetscHashIterGetKey(g2l, iter, key); 363 PetscHashIterNext(g2l, iter); 364 garray2[tot++] = key; 365 } 366 367 // Sort garray2[] and then map them to local indices starting from 0 368 PetscCall(PetscSortInt(n2, garray2)); 369 PetscCall(PetscHMapIClear(g2l)); 370 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 371 372 // Rewrite indices[] with local indices 373 for (PetscInt i = 0; i < m; i++) { 374 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 375 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 376 indices[i] = val; 377 } 378 // Record the map that maps garray1[i] to garray2[map[i]] 379 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 380 PetscCall(PetscHMapIDestroy(&g2l)); 381 *n2_ = n2; 382 *garray2_ = garray2; 383 PetscFunctionReturn(PETSC_SUCCESS); 384 } 385 386 /* 387 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 388 
  It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
  It is provided in split-phase form, MatMPIAIJKokkosReduceBegin() and MatMPIAIJKokkosReduceEnd(), which take the same arguments.

  Think of each row of E as a leaf; the given ownerSF then specifies the roots of these leaves. A root may connect to multiple leaves.
  In this routine, we sparse-merge the leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be the nroots of ownerSF.

  Input Parameters:
+ comm - MPI communicator of E
. A - diag block of E, using local column indices
. B - off-diag block of E, using local column indices
. cstart - (global) start column of Ed
. cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend)
. garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
. ownerSF - the SF that specifies the ownership (root) of rows in E
. reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm - to stash intermediate data structures for reuse

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
- mm - contains various info, such as garray2[], F (Fd, Fo) etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
411 412 */ 413 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 414 { 415 PetscFunctionBegin; 416 if (reuse == MAT_INITIAL_MATRIX) { 417 PetscInt Em = A.numRows(), Fm; 418 PetscInt n1 = B.numCols(); 419 420 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 421 422 // Do the analysis on host 423 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 424 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 425 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 426 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 427 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 428 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 429 430 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 431 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 432 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 433 for (PetscInt i = 0; i < Em; i++) { 434 const PetscInt *first, *last, *it; 435 PetscInt count, step; 436 // std::lower_bound(first,last,cstart), but need to use global column indices 437 first = Bj + Bi[i]; 438 last = Bj + Bi[i + 1]; 439 count = last - first; 440 while (count > 0) { 441 it = first; 442 step = count / 2; 443 it += step; 444 if (garray1[*it] < cstart) { // map local to global 445 first = ++it; 446 count -= step + 1; 447 } else count = step; 448 } 449 E_NzLeft[i] = first - (Bj + Bi[i]); 450 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 451 } 452 453 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 454 const PetscMPIInt *iranks, *ranks; 455 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 456 PetscInt niranks, nranks; 457 MPI_Request *reqs; 458 PetscMPIInt tag; 459 PetscSF reduceSF; 460 PetscInt *sdisp, *rdisp; 461 462 PetscCall(PetscCommGetNewTag(comm, &tag)); 463 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 464 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 465 466 // Find out length of each row I will receive. Even for the same row index, when they are from 467 // different senders, they might have different lengths (and sparsity patterns) 468 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 469 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 470 471 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 472 473 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 474 recvRowLen[0] = 0; // since we will make it in CSR format later 475 recvRowLen++; // advance the pointer now 476 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 477 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 478 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 479 480 // Build the real PetscSF for reducing E rows (buffer to buffer) 481 rdisp[0] = 0; 482 for (PetscInt i = 0; i < niranks; i++) { 483 rdisp[i + 1] = rdisp[i]; 484 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 485 } 486 recvRowLen--; // put it back into csr format 487 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 
1] += recvRowLen[i]; 488 489 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 490 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 491 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 492 493 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 494 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 495 PetscSFNode *iremote; 496 497 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 498 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 499 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 500 501 for (PetscInt i = 0; i < nranks; i++) { 502 PetscInt count = 0; 503 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 504 for (PetscInt j = 0; j < count; j++) { 505 iremote[nleaves + j].rank = ranks[i]; 506 iremote[nleaves + j].index = sdisp[i] + j; 507 } 508 nleaves += count; 509 } 510 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 511 512 PetscCall(PetscSFCreate(comm, &reduceSF)); 513 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 514 515 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 516 PetscInt *sendCol, *recvCol; 517 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 518 for (PetscInt k = 0; k < roffset[nranks]; k++) { 519 PetscInt i = rmine[k]; // row to be copied 520 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 521 PetscInt nzLeft = E_NzLeft[i]; 522 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 523 for (PetscInt j = 0; j < alen + blen; j++) { 524 if (j < nzLeft) { 525 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 526 } else if (j < nzLeft + alen) { 527 buf[j] = Aj[Ai[i] 
+ j - nzLeft] + cstart; // diag A, also in global 528 } else { 529 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 530 } 531 } 532 } 533 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 534 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 535 536 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 537 PetscInt *recvRowPerm, *recvColSorted; 538 PetscInt *recvNzPerm, *recvNzPermSorted; 539 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 540 541 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 542 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 543 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 544 545 // i[] array, nz are always easiest to compute 546 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 547 MatRowMapType *Fdi, *Foi; 548 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 549 PetscInt iter; 550 551 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 552 Kokkos::deep_copy(Foi_h, 0); 553 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 554 Foi = Foi_h.data() + 1; 555 iter = 0; 556 while (iter < recvRowCnt) { // iter over received rows 557 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 558 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 559 560 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 561 562 // Copy column indices 
(and their permutation) of these rows into recvColSorted & recvNzPermSorted 563 PetscInt nz = 0; // nz (with dups) in the current row 564 PetscInt *jbuf = recvColSorted + FnzDups; 565 PetscInt *pbuf = recvNzPermSorted + FnzDups; 566 PetscInt *jbuf2 = jbuf; // temp pointers 567 PetscInt *pbuf2 = pbuf; 568 for (PetscInt d = 0; d < dupRows; d++) { 569 PetscInt i = recvRowPerm[iter + d]; 570 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 571 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 572 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 573 jbuf2 += len; 574 pbuf2 += len; 575 nz += len; 576 } 577 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 578 579 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 580 PetscInt cur = 0; 581 while (cur < nz) { 582 PetscInt curColIdx = jbuf[cur]; 583 PetscInt dups = 1; 584 585 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 586 if (curColIdx >= cstart && curColIdx < cend) { 587 Fdi[curRowIdx]++; 588 FdnzDups += dups; 589 } else { 590 Foi[curRowIdx]++; 591 FonzDups += dups; 592 } 593 cur += dups; 594 } 595 596 FnzDups += nz; 597 iter += dupRows; // Move to next unique row 598 } 599 600 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 601 Foi = Foi_h.data(); 602 for (PetscInt i = 0; i < Fm; i++) { 603 Fdi[i + 1] += Fdi[i]; 604 Foi[i + 1] += Foi[i]; 605 } 606 Fdnz = Fdi[Fm]; 607 Fonz = Foi[Fm]; 608 PetscCall(PetscFree2(sendCol, recvCol)); 609 610 // Allocate j, jmap, jperm for Fd and Fo 611 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 612 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 613 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 614 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 615 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 616 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 617 618 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 619 Fdjmap[0] = 0; 620 Fojmap[0] = 0; 621 FnzDups = 0; 622 Fdnz = 0; 623 Fonz = 0; 624 iter = 0; // iter over received rows 625 while (iter < recvRowCnt) { 626 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 627 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 628 PetscInt nz = 0; // nz (with dups) in the current row 629 630 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 631 for (PetscInt d = 0; d < dupRows; d++) { 632 PetscInt i = recvRowPerm[iter + d]; 633 nz += recvRowLen[i + 1] - recvRowLen[i]; 634 } 635 636 PetscInt *jbuf = recvColSorted + FnzDups; 637 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 638 PetscInt cur = 0; 639 while (cur < nz) { 640 PetscInt curColIdx = jbuf[cur]; 641 PetscInt dups = 1; 642 643 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 644 if (curColIdx >= cstart && curColIdx < cend) { 645 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 646 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 647 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 648 FdnzDups += dups; 649 Fdnz++; 650 } else { 651 Foj[Fonz] = curColIdx; // in global 652 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 653 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 654 FonzDups += dups; 655 Fonz++; 656 } 657 cur += dups; 658 FnzDups += dups; 659 } 660 iter += dupRows; // Move to next unique row 661 } 662 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 663 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 664 665 // Combine global column indices in garray1[] and Foj[] 666 PetscInt n2, *garray2; 667 668 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 669 mm->sf = reduceSF; 670 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 671 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 672 mm->garray = garray2; // give ownership, so no free 673 mm->n = n2; 674 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 675 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 676 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 677 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 678 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 679 680 // Output Fd and Fo in KokkosCsrMatrix format 681 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 682 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 683 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 684 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 685 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 686 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 687 688 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 689 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 690 691 // Compute kernel launch parameters in merging E 692 PetscInt teamSize, vectorLength, rowsPerTeam; 693 694 teamSize = vectorLength = rowsPerTeam = -1; 695 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 696 mm->E_TeamSize = teamSize; 697 mm->E_VectorLength = 
vectorLength; 698 mm->E_RowsPerTeam = rowsPerTeam; 699 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 700 701 // Handy aliases 702 auto &Aa = A.values; 703 auto &Ba = B.values; 704 const auto &Ai = A.graph.row_map; 705 const auto &Bi = B.graph.row_map; 706 const auto &E_NzLeft = mm->E_NzLeft; 707 auto &leafBuf = mm->leafBuf; 708 auto &rootBuf = mm->rootBuf; 709 PetscSF reduceSF = mm->sf; 710 PetscInt Em = A.numRows(); 711 PetscInt teamSize = mm->E_TeamSize; 712 PetscInt vectorLength = mm->E_VectorLength; 713 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 714 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 715 716 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 717 PetscCallCXX(Kokkos::parallel_for( 718 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 719 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 720 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 721 if (i < Em) { 722 PetscInt disp = Ai(i) + Bi(i); 723 PetscInt alen = Ai(i + 1) - Ai(i); 724 PetscInt blen = Bi(i + 1) - Bi(i); 725 PetscInt nzleft = E_NzLeft(i); 726 727 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 728 MatScalar &val = leafBuf(disp + j); 729 if (j < nzleft) { // B left 730 val = Ba(Bi(i) + j); 731 } else if (j < nzleft + alen) { // diag A 732 val = Aa(Ai(i) + j - nzleft); 733 } else { // B right 734 val = Ba(Bi(i) + j - alen); 735 } 736 }); 737 } 738 }); 739 })); 740 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 741 PetscFunctionReturn(PETSC_SUCCESS); 742 } 743 744 // To finish MatMPIAIJKokkosReduce. 
745 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 746 { 747 PetscFunctionBegin; 748 auto &leafBuf = mm->leafBuf; 749 auto &rootBuf = mm->rootBuf; 750 auto &Fda = mm->Fd.values; 751 const auto &Fdjmap = mm->Fdjmap; 752 const auto &Fdjperm = mm->Fdjperm; 753 auto Fdnz = mm->Fd.nnz(); 754 auto &Foa = mm->Fo.values; 755 const auto &Fojmap = mm->Fojmap; 756 const auto &Fojperm = mm->Fojperm; 757 auto Fonz = mm->Fo.nnz(); 758 PetscSF reduceSF = mm->sf; 759 760 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 761 762 // Reduce data in rootBuf to Fd and Fo 763 PetscCallCXX(Kokkos::parallel_for( 764 Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) { 765 PetscScalar sum = 0.0; 766 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 767 Fda(i) = sum; 768 })); 769 770 PetscCallCXX(Kokkos::parallel_for( 771 Fonz, KOKKOS_LAMBDA(const MatRowMapType i) { 772 PetscScalar sum = 0.0; 773 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 774 Foa(i) = sum; 775 })); 776 PetscFunctionReturn(PETSC_SUCCESS); 777 } 778 779 /* 780 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 781 782 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 783 device and involves various index mapping. 784 785 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 786 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 787 to j-th row of F. 
ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 788 F has the same column layout as E. 789 790 Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 791 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 792 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 793 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 794 column indices in Fo and update Fo with local indices. 795 796 Input Parameters: 797 + E - the MPIAIJKOKKOS matrix 798 . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 799 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 800 - mm - to stash matproduct intermediate data structures 801 802 Output Parameters: 803 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 804 - mm - contains various info, such as garray2[], Fd, Fo, etc. 805 806 Notes: 807 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 808 The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 
809 */ 810 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 811 { 812 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 813 Mat A = empi->A, B = empi->B; // diag and off-diag 814 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 815 PetscInt Em = E->rmap->n; // #local rows 816 MPI_Comm comm; 817 818 PetscFunctionBegin; 819 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 820 if (reuse == MAT_INITIAL_MATRIX) { 821 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 822 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 823 const PetscInt *garray1 = empi->garray; // its size is n1 824 PetscInt cstart, cend; 825 PetscSF bcastSF; 826 827 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 828 829 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 830 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 831 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 832 for (PetscInt i = 0; i < Em; i++) { 833 const PetscInt *first, *last, *it; 834 PetscInt count, step; 835 // std::lower_bound(first,last,cstart), but need to use global column indices 836 first = Bj + Bi[i]; 837 last = Bj + Bi[i + 1]; 838 count = last - first; 839 while (count > 0) { 840 it = first; 841 step = count / 2; 842 it += step; 843 if (empi->garray[*it] < cstart) { // map local to global 844 first = ++it; 845 count -= step + 1; 846 } else count = step; 847 } 848 E_NzLeft[i] = first - (Bj + Bi[i]); 849 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 850 } 851 852 // Compute row pointer Fi of F 853 PetscInt *Fi, Fm, Fnz; 854 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 855 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 856 Fi[0] = 0; 857 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 858 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 859 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 860 Fnz = Fi[Fm]; 861 862 // Build the real PetscSF for bcasting E rows (buffer to buffer) 863 const PetscMPIInt *iranks, *ranks; 864 const PetscInt *ioffset, *irootloc, *roffset; 865 PetscInt niranks, nranks, *sdisp, *rdisp; 866 MPI_Request *reqs; 867 PetscMPIInt tag; 868 869 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 870 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 871 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 872 873 sdisp[0] = 0; // send displacement 874 for (PetscInt i = 0; i < niranks; i++) { 875 sdisp[i + 1] = sdisp[i]; 876 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 877 PetscInt r = irootloc[j]; // row to be sent 878 sdisp[i + 1] += E_RowLen[r]; 879 } 880 } 881 882 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 883 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 884 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 885 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 886 887 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 888 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 889 PetscSFNode *iremote; // give ownership to bcastSF 890 PetscCall(PetscMalloc1(nleaves, &iremote)); 891 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 892 PetscInt k = 0; 893 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 894 iremote[j].rank = ranks[i]; 895 iremote[j].index = rdisp[i] + k; // their root location 896 k++; 897 } 898 } 899 PetscCall(PetscSFCreate(comm, &bcastSF)); 900 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 901 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 902 903 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 904 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 905 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 906 rowoffset[0] = 0; 907 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 908 909 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 910 PetscInt *jbuf, *Fj; 911 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 912 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 913 PetscInt i = irootloc[k]; // row to be copied 914 PetscInt *buf = &jbuf[rowoffset[k]]; 915 PetscInt nzLeft = E_NzLeft[i]; 916 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 917 for (PetscInt j = 0; j < alen + blen; j++) { 918 if (j < nzLeft) { 919 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 920 } else if (j < nzLeft + alen) { 921 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 922 } else { 923 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 924 } 925 } 926 } 927 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 928 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 929 930 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 931 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 932 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 933 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 934 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 935 936 Fdi[0] = Foi[0] = 0; 937 for (PetscInt i = 0; i < Fm; i++) { 938 PetscInt *first, *last, *lb1, *lb2; 939 // cut the row into: Left, [cstart, cend), Right 940 first = Fj + Fi[i]; 941 last = Fj + Fi[i + 1]; 942 lb1 = std::lower_bound(first, last, cstart); 943 F_NzLeft[i] = lb1 - first; 944 lb2 = std::lower_bound(first, last, cend); 945 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 946 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 947 } 948 for (PetscInt i = 0; i < Fm; i++) { 949 Fdi[i + 1] += Fdi[i]; 950 Foi[i + 1] += Foi[i]; 951 } 952 953 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 954 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 955 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 956 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 957 958 for (PetscInt i = 0; i < Fm; i++) { 959 PetscInt nzLeft = F_NzLeft[i]; 960 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 961 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 962 gid = Fj[Fi[i] + j]; 963 if (j < nzLeft) { // left, in global 964 Foj[Foi[i] + j] = gid; 965 } else if (j < nzLeft + len) { // diag, in local 966 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 967 } else { // right, in global 968 Foj[Foi[i] + j - len] = gid; 969 } 970 } 971 } 972 PetscCall(PetscFree2(jbuf, Fj)); 973 PetscCall(PetscFree(Fi)); 974 975 // Reduce global indices in Foj[] and garray1[] into local ones 976 PetscInt n2, *garray2; 977 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 978 979 // Record the plans built above, for reuse 980 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 981 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 982 Kokkos::deep_copy(irootloc_h, tmp); 983 mm->sf = bcastSF; 984 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 985 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 986 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 987 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 988 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 989 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 990 mm->garray = garray2; 991 mm->n = n2; 992 993 // Output Fd and Fo in KokkosCsrMatrix format 994 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 995 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 996 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 997 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 998 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 999 1000 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 1001 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 1002 1003 // Compute kernel launch parameters in merging E or splitting F 1004 PetscInt teamSize, vectorLength, rowsPerTeam; 1005 1006 teamSize = vectorLength = rowsPerTeam = -1; 1007 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 1008 mm->E_TeamSize = teamSize; 1009 mm->E_VectorLength = vectorLength; 1010 mm->E_RowsPerTeam = rowsPerTeam; 1011 1012 teamSize = vectorLength = rowsPerTeam = -1; 1013 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, 
-1, teamSize, vectorLength, rowsPerTeam)); 1014 mm->F_TeamSize = teamSize; 1015 mm->F_VectorLength = vectorLength; 1016 mm->F_RowsPerTeam = rowsPerTeam; 1017 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 1018 1019 // Sync E's value to device 1020 akok->a_dual.sync_device(); 1021 bkok->a_dual.sync_device(); 1022 1023 // Handy aliases 1024 const auto &Aa = akok->a_dual.view_device(); 1025 const auto &Ba = bkok->a_dual.view_device(); 1026 const auto &Ai = akok->i_dual.view_device(); 1027 const auto &Bi = bkok->i_dual.view_device(); 1028 1029 // Fetch the plans 1030 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1031 PetscSF &bcastSF = mm->sf; 1032 MatScalarKokkosView &rootBuf = mm->rootBuf; 1033 MatScalarKokkosView &leafBuf = mm->leafBuf; 1034 PetscIntKokkosView &irootloc = mm->irootloc; 1035 PetscIntKokkosView &rowoffset = mm->rowoffset; 1036 1037 PetscInt teamSize = mm->E_TeamSize; 1038 PetscInt vectorLength = mm->E_VectorLength; 1039 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1040 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1041 1042 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1043 PetscCallCXX(Kokkos::parallel_for( 1044 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1045 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1046 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1047 if (r < irootloc.extent(0)) { 1048 PetscInt i = irootloc(r); // row i of E 1049 PetscInt disp = rowoffset(r); 1050 PetscInt alen = Ai(i + 1) - Ai(i); 1051 PetscInt blen = Bi(i + 1) - Bi(i); 1052 PetscInt nzleft = E_NzLeft(i); 1053 1054 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1055 if (j < nzleft) { // B left 1056 rootBuf(disp + j) = Ba(Bi(i) + j); 1057 } else if (j < nzleft + alen) { // diag A 1058 rootBuf(disp + j) = Aa(Ai(i) + j - 
nzleft); 1059 } else { // B right 1060 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1061 } 1062 }); 1063 } 1064 }); 1065 })); 1066 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1067 PetscFunctionReturn(PETSC_SUCCESS); 1068 } 1069 1070 // To finish MatMPIAIJKokkosBcast. 1071 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1072 { 1073 PetscFunctionBegin; 1074 const auto &Fd = mm->Fd; 1075 const auto &Fo = mm->Fo; 1076 const auto &Fdi = Fd.graph.row_map; 1077 const auto &Foi = Fo.graph.row_map; 1078 auto &Fda = Fd.values; 1079 auto &Foa = Fo.values; 1080 auto Fm = Fd.numRows(); 1081 1082 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1083 PetscSF &bcastSF = mm->sf; 1084 MatScalarKokkosView &rootBuf = mm->rootBuf; 1085 MatScalarKokkosView &leafBuf = mm->leafBuf; 1086 PetscInt teamSize = mm->F_TeamSize; 1087 PetscInt vectorLength = mm->F_VectorLength; 1088 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1089 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1090 1091 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1092 1093 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1094 PetscCallCXX(Kokkos::parallel_for( 1095 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1096 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1097 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1098 if (i < Fm) { 1099 PetscInt nzLeft = F_NzLeft(i); 1100 PetscInt alen = Fdi(i + 1) - Fdi(i); 1101 PetscInt blen = Foi(i + 1) - Foi(i); 1102 PetscInt Fii = Fdi(i) + Foi(i); 1103 1104 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1105 PetscScalar val = leafBuf(Fii + j); 1106 if (j < nzLeft) { // left 1107 Foa(Foi(i) + j) = val; 
} else if (j < nzLeft + alen) { // diag
              Fda(Fdi(i) + j - nzLeft) = val;
            } else { // right: skip over the alen diagonal entries that sit between left and right
              Foa(Foi(i) + j - alen) = val;
            }
          });
        }
      });
    }));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatProductSymbolic_MPIAIJKokkos_AtB - symbolic phase of C = A^T * B for MPIAIJKOKKOS A and B.
  It also runs the numeric kernels (spgemm_numeric/spadd_numeric), so mm->Cd and mm->Co hold values on return.

  With A = (Ad, Ao) and B = (Bd, Bo) in split form:
    C1, C2_mid = Adt * Bd, Adt * Bo                 (contributions to locally owned rows of C)
    C3, C4     = Aot * Bd, Aot * Bo                 (matrix E, whose rows are reduced via ampi->Mvctx into F = (Fd, Fo))
    mm->Cd     = C1 + Fd,  mm->Co = C2 + Fo         (C2 is C2_mid with columns renumbered through map_h)
  The reduction of E is started before the Adt products and finished after them to overlap communication.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
  PetscInt        cstart, cend;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); // NOTE(review): Ad/Ao are not used below; kept in case the call has needed side effects
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
  PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

  // TODO: add command line options to select spgemm algorithms
  auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

  // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
  #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
  spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
  #endif
#endif

  // One handle per product so each can be reused independently in the numeric phase
  PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
  PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));

  // Aot * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
  // KK spgemm_symbolic() only populates the result's row map, but not its columns.
  // TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C3));
  PetscCallCXX(sort_crs_matrix(mm->C4));
#endif

  // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
  PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); // maps B's off-diag columns (garray) to the merged column numbering
  PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
  PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Adt * (B's diag + B's off-diag)
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
  PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  PetscCallCXX(sort_crs_matrix(mm->C1));
  PetscCallCXX(sort_crs_matrix(mm->C2_mid));
#endif

  PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

  // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
  MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
  PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
  PetscCallCXX(Kokkos::parallel_for(oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
  PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));

  // C = (C1+Fd, C2+Fo)
  PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
  PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
  PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatProductNumeric_MPIAIJKokkos_AtB - numeric phase of C = A^T * B. Reuses the spgemm/spadd handles, the sparsity
  patterns and the reduction plan stored in mm by MatProductSymbolic_MPIAIJKokkos_AtB() (hence MAT_REUSE_MATRIX below).
*/
static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
{
  Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
  Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
  KokkosCsrMatrix Adt, Aot, Bd, Bo;
  MPI_Comm        comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
  PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1208 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1209 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1210 1211 // Aot * (B's diag + B's off-diag) 1212 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1213 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1214 1215 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1216 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1217 1218 // Adt * (B's diag + B's off-diag) 1219 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1220 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1221 1222 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1223 1224 // C = (C1+Fd, C2+Fo) 1225 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1226 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1227 PetscFunctionReturn(PETSC_SUCCESS); 1228 } 1229 1230 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1231 1232 Input Parameters: 1233 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1234 . A - an MPIAIJKOKKOS matrix 1235 . B - an MPIAIJKOKKOS matrix 1236 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1237 */ 1238 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1239 { 1240 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1241 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1242 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1243 1244 PetscFunctionBegin; 1245 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1246 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1247 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1248 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1249 1250 // TODO: add command line options to select spgemm algorithms 1251 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1252 1253 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1254 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1255 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1256 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1257 #endif 1258 #endif 1259 1260 mm->kh1.create_spgemm_handle(spgemm_alg); 1261 mm->kh2.create_spgemm_handle(spgemm_alg); 1262 mm->kh3.create_spgemm_handle(spgemm_alg); 1263 mm->kh4.create_spgemm_handle(spgemm_alg); 1264 1265 // Bcast B's rows to form F, and overlap the communication 1266 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1267 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1268 1269 // A's diag * (B's diag + B's off-diag) 1270 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1271 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1272 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1273 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1274 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1275 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1276 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1277 PetscCallCXX(sort_crs_matrix(mm->C1)); 1278 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1279 #endif 1280 1281 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1282 1283 // A's off-diag * (F's diag + F's off-diag) 1284 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1285 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1286 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1287 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1288 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1289 PetscCallCXX(sort_crs_matrix(mm->C3)); 1290 PetscCallCXX(sort_crs_matrix(mm->C4)); 1291 #endif 1292 1293 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1294 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1295 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1296 PetscCallCXX(Kokkos::parallel_for(oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1297 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1298 1299 // C = (Cd, Co) = (C1+C3, C2+C4) 1300 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1301 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1302 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1303 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1304 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, 
mm->C1, 1.0, mm->C3, mm->Cd)); 1305 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1306 PetscFunctionReturn(PETSC_SUCCESS); 1307 } 1308 1309 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1310 { 1311 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1312 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1313 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1314 1315 PetscFunctionBegin; 1316 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1317 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1318 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1319 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1320 1321 // Bcast B's rows to form F, and overlap the communication 1322 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1323 1324 // A's diag * (B's diag + B's off-diag) 1325 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1326 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1327 1328 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1329 1330 // A's off-diag * (F's diag + F's off-diag) 1331 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1332 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1333 1334 // C = (Cd, Co) = (C1+C3, C2+C4) 1335 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1336 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1337 PetscFunctionReturn(PETSC_SUCCESS); 1338 } 1339 1340 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1341 { 1342 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1343 Mat_Product *product; 1344 MatProductData_MPIAIJKokkos *pdata; 1345 MatProductType ptype; 1346 Mat A, B; 1347 1348 PetscFunctionBegin; 
MatCheckProduct(C, 1); // make sure C is a product
  product = C->product;
  pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
  ptype   = product->type;
  A       = product->A;
  B       = product->B;

  // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
  // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
  // we still do numeric.
  if (pdata->reusesym) { // numeric reuses results from symbolic
    pdata->reusesym = PETSC_FALSE;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  // Other product types are rejected by MatProductSymbolic_MPIAIJKokkos(), so no else branch is needed here
  if (ptype == MATPRODUCT_AB) {
    PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
  } else if (ptype == MATPRODUCT_AtB) {
    PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
  } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ
    PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
    PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
  }

  PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified
  PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatProductSymbolic_MPIAIJKokkos - symbolic entry point for AB, AtB and PtAP (computed as BtAB via Z = AB; C = BtZ).
  Sizes and types C, builds the per-product intermediate structs (which also compute values), wraps the resulting
  split matrices (Cd, Co) into C, and installs MatProductNumeric_MPIAIJKokkos as C's numeric routine.
*/
static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
{
  Mat                          A, B;
  Mat_Product                 *product;
  MatProductType               ptype;
  MatProductData_MPIAIJKokkos *pdata;
  MatMatStruct                *mm = NULL;
  PetscInt                     m, n, M, N;
  Mat                          Cd, Co;
  MPI_Comm                     comm;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
  MatCheckProduct(C, 1);
  product = C->product;
  PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
  ptype = product->type;
  A     = product->A;
  B     = product->B;

  // Local (m, n) and global (M, N) sizes of the result C
  switch (ptype) {
  case MATPRODUCT_AB:
    m = A->rmap->n;
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_AtB:
    m = A->cmap->n;
    n = B->cmap->n;
    M = A->cmap->N;
    N = B->cmap->N;
    break;
  case MATPRODUCT_PtAP:
    m = B->cmap->n;
    n = B->cmap->n;
    M = B->cmap->N;
    N = B->cmap->N;
    break; /* BtAB */
  default:
    SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
  }

  PetscCall(MatSetSizes(C, m, n, M, N));
  PetscCall(PetscLayoutSetUp(C->rmap));
  PetscCall(PetscLayoutSetUp(C->cmap));
  PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

  // NOTE(review): pdata (and the mm structs below) are attached to C->product only at the end of this routine;
  // if an intermediate PetscCall() returns early they appear to leak -- confirm and consider attaching earlier.
  pdata           = new MatProductData_MPIAIJKokkos();
  pdata->reusesym = product->api_user; // symbolic also computes values, so the first numeric call can be skipped

  if (ptype == MATPRODUCT_AB) {
    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
    mm = pdata->mmAB = mmAB;
  } else if (ptype == MATPRODUCT_AtB) {
    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
    mm = pdata->mmAtB = mmAtB;
  } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ
    Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z

    auto mmAB = new MatMatStruct_AB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
    PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
    pdata->mmAB = mmAB;

    m = A->rmap->n; // Z's layout
    n = B->cmap->n;
    M = A->rmap->N;
    N = B->cmap->N;
    PetscCall(MatCreate(comm, &Z));
    PetscCall(MatSetSizes(Z, m, n, M, N));
    PetscCall(PetscLayoutSetUp(Z->rmap));
    PetscCall(PetscLayoutSetUp(Z->cmap));
    PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
    PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));

    auto mmAtB = new MatMatStruct_AtB();
    PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

    pdata->Z = Z; // give ownership to pdata
    mm = pdata->mmAtB = mmAtB;
  }

  // Wrap the split result (mm->Cd, mm->Co, mm->garray) into the MPIAIJKOKKOS matrix C
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
  PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
  PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));

  C->product->data       = pdata;
  C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
  C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
  PetscFunctionReturn(PETSC_SUCCESS);
}

PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
{
  Mat_Product *product = mat->product;
  PetscBool    match   = PETSC_FALSE;
  PetscBool    usecpu  = PETSC_FALSE;

  PetscFunctionBegin;
  MatCheckProduct(mat, 1);
  if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
  if (match) { /* we can always fallback to the CPU if requested */
    switch (product->type) {
    case MATPRODUCT_AB:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
        PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    case MATPRODUCT_AtB:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat),
((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1499 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1500 PetscOptionsEnd(); 1501 } else { 1502 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1503 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1504 PetscOptionsEnd(); 1505 } 1506 break; 1507 case MATPRODUCT_PtAP: 1508 if (product->api_user) { 1509 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1510 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1511 PetscOptionsEnd(); 1512 } else { 1513 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1514 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1515 PetscOptionsEnd(); 1516 } 1517 break; 1518 default: 1519 break; 1520 } 1521 match = (PetscBool)!usecpu; 1522 } 1523 if (match) { 1524 switch (product->type) { 1525 case MATPRODUCT_AB: 1526 case MATPRODUCT_AtB: 1527 case MATPRODUCT_PtAP: 1528 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1529 break; 1530 default: 1531 break; 1532 } 1533 } 1534 /* fallback to MPIAIJ ops */ 1535 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1536 PetscFunctionReturn(PETSC_SUCCESS); 1537 } 1538 1539 // Mirror of MatCOOStruct_MPIAIJ on device 1540 struct MatCOOStruct_MPIAIJKokkos { 1541 PetscCount n; 1542 PetscSF sf; 1543 PetscCount Annz, Bnnz; 1544 PetscCount Annz2, Bnnz2; 1545 PetscCountKokkosView Ajmap1, Aperm1; 1546 PetscCountKokkosView Bjmap1, Bperm1; 1547 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1548 PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1549 PetscCountKokkosView Cperm1; 1550 
MatScalarKokkosView sendbuf, recvbuf; 1551 1552 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) : 1553 n(coo_h->n), 1554 sf(coo_h->sf), 1555 Annz(coo_h->Annz), 1556 Bnnz(coo_h->Bnnz), 1557 Annz2(coo_h->Annz2), 1558 Bnnz2(coo_h->Bnnz2), 1559 Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))), 1560 Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))), 1561 Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))), 1562 Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))), 1563 Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))), 1564 Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))), 1565 Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))), 1566 Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))), 1567 Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))), 1568 Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))), 1569 Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))), 1570 sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))), 1571 recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen))) 1572 { 1573 
PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1574 } 1575 1576 ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1577 }; 1578 1579 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1580 { 1581 PetscFunctionBegin; 1582 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1583 PetscFunctionReturn(PETSC_SUCCESS); 1584 } 1585 1586 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1587 { 1588 PetscContainer container_h, container_d; 1589 MatCOOStruct_MPIAIJ *coo_h; 1590 MatCOOStruct_MPIAIJKokkos *coo_d; 1591 1592 PetscFunctionBegin; 1593 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1594 mat->preallocated = PETSC_TRUE; 1595 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1596 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1597 PetscCall(MatZeroEntries(mat)); 1598 1599 // Copy the COO struct to device 1600 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1601 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1602 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1603 1604 // Put the COO struct in a container and then attach that to the matrix 1605 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1606 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1607 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1608 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1609 PetscCall(PetscContainerDestroy(&container_d)); 1610 PetscFunctionReturn(PETSC_SUCCESS); 1611 } 1612 1613 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1614 { 1615 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1616 Mat A = mpiaij->A, B = 
mpiaij->B; 1617 MatScalarKokkosView Aa, Ba; 1618 MatScalarKokkosView v1; 1619 PetscMemType memtype; 1620 PetscContainer container; 1621 MatCOOStruct_MPIAIJKokkos *coo; 1622 1623 PetscFunctionBegin; 1624 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1625 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1626 1627 const auto &n = coo->n; 1628 const auto &Annz = coo->Annz; 1629 const auto &Annz2 = coo->Annz2; 1630 const auto &Bnnz = coo->Bnnz; 1631 const auto &Bnnz2 = coo->Bnnz2; 1632 const auto &vsend = coo->sendbuf; 1633 const auto &v2 = coo->recvbuf; 1634 const auto &Ajmap1 = coo->Ajmap1; 1635 const auto &Ajmap2 = coo->Ajmap2; 1636 const auto &Aimap2 = coo->Aimap2; 1637 const auto &Bjmap1 = coo->Bjmap1; 1638 const auto &Bjmap2 = coo->Bjmap2; 1639 const auto &Bimap2 = coo->Bimap2; 1640 const auto &Aperm1 = coo->Aperm1; 1641 const auto &Aperm2 = coo->Aperm2; 1642 const auto &Bperm1 = coo->Bperm1; 1643 const auto &Bperm2 = coo->Bperm2; 1644 const auto &Cperm1 = coo->Cperm1; 1645 1646 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1647 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1648 v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n)); 1649 } else { 1650 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1651 } 1652 1653 if (imode == INSERT_VALUES) { 1654 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1655 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1656 } else { 1657 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1658 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1659 } 1660 1661 PetscCall(PetscLogGpuTimeBegin()); 1662 /* Pack entries to be sent to remote */ 1663 Kokkos::parallel_for(vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); 
}); 1664 1665 /* Send remote entries to their owner and overlap the communication with local computation */ 1666 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1667 /* Add local entries to A and B in one kernel */ 1668 Kokkos::parallel_for( 1669 Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) { 1670 PetscScalar sum = 0.0; 1671 if (i < Annz) { 1672 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1673 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1674 } else { 1675 i -= Annz; 1676 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1677 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1678 } 1679 }); 1680 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1681 1682 /* Add received remote entries to A and B in one kernel */ 1683 Kokkos::parallel_for( 1684 Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) { 1685 if (i < Annz2) { 1686 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1687 } else { 1688 i -= Annz2; 1689 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1690 } 1691 }); 1692 PetscCall(PetscLogGpuTimeEnd()); 1693 1694 if (imode == INSERT_VALUES) { 1695 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1696 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1697 } else { 1698 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1699 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1700 } 1701 PetscFunctionReturn(PETSC_SUCCESS); 1702 } 1703 1704 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1705 { 1706 PetscFunctionBegin; 1707 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1708 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1709 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1710 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1711 PetscCall(MatDestroy_MPIAIJ(A)); 1712 PetscFunctionReturn(PETSC_SUCCESS); 1713 } 1714 1715 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1716 { 1717 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1718 PetscBool congruent; 1719 1720 PetscFunctionBegin; 1721 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1722 if (congruent) { // square matrix and the diagonals are solely in the diag block 1723 PetscCall(MatShift(mpiaij->A, a)); 1724 } else { // too hard, use the general version 1725 PetscCall(MatShift_Basic(A, a)); 1726 } 1727 PetscFunctionReturn(PETSC_SUCCESS); 1728 } 1729 1730 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1731 { 1732 PetscFunctionBegin; 1733 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1734 B->ops->mult = MatMult_MPIAIJKokkos; 1735 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1736 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1737 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1738 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1739 B->ops->shift = MatShift_MPIAIJKokkos; 1740 1741 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1742 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1743 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1744 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1745 PetscFunctionReturn(PETSC_SUCCESS); 1746 } 1747 1748 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1749 { 1750 Mat B; 1751 Mat_MPIAIJ *a; 1752 1753 PetscFunctionBegin; 1754 if (reuse == MAT_INITIAL_MATRIX) { 1755 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1756 } else if (reuse == MAT_REUSE_MATRIX) { 1757 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1758 } 1759 B = *newmat; 1760 1761 B->boundtocpu = PETSC_FALSE; 1762 PetscCall(PetscFree(B->defaultvectype)); 1763 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1764 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1765 1766 a = static_cast<Mat_MPIAIJ *>(A->data); 1767 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1768 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1769 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1770 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1771 PetscFunctionReturn(PETSC_SUCCESS); 1772 } 1773 1774 /*MC 1775 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1776 1777 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1778 1779 Options Database Key: 1780 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1781 1782 Level: beginner 1783 1784 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1785 M*/ 1786 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1787 { 1788 PetscFunctionBegin; 1789 PetscCall(PetscKokkosInitializeCheck()); 1790 PetscCall(MatCreate_MPIAIJ(A)); 1791 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1792 PetscFunctionReturn(PETSC_SUCCESS); 1793 } 1794 1795 /*@C 1796 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1797 (the default parallel PETSc format). This matrix will ultimately pushed down 1798 to Kokkos for calculations. 1799 1800 Collective 1801 1802 Input Parameters: 1803 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1804 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1805 This value should be the same as the local size used in creating the 1806 y vector for the matrix-vector product y = Ax. 1807 . n - This value should be the same as the local size used in creating the 1808 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1809 calculated if N is given) For square matrices n is almost always `m`. 1810 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1811 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1812 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1813 (same value is used for all local rows) 1814 . d_nnz - array containing the number of nonzeros in the various rows of the 1815 DIAGONAL portion of the local submatrix (possibly different for each row) 1816 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1817 The size of this array is equal to the number of local rows, i.e `m`. 
1818 For matrices you plan to factor you must leave room for the diagonal entry and 1819 put in the entry even if it is zero. 1820 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1821 submatrix (same value is used for all local rows). 1822 - o_nnz - array containing the number of nonzeros in the various rows of the 1823 OFF-DIAGONAL portion of the local submatrix (possibly different for 1824 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1825 structure. The size of this array is equal to the number 1826 of local rows, i.e `m`. 1827 1828 Output Parameter: 1829 . A - the matrix 1830 1831 Level: intermediate 1832 1833 Notes: 1834 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1835 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1836 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1837 1838 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1839 storage. That is, the stored row and column indices can begin at 1840 either one (as in Fortran) or zero. 
1841 1842 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1843 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1844 @*/ 1845 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1846 { 1847 PetscMPIInt size; 1848 1849 PetscFunctionBegin; 1850 PetscCall(MatCreate(comm, A)); 1851 PetscCall(MatSetSizes(*A, m, n, M, N)); 1852 PetscCallMPI(MPI_Comm_size(comm, &size)); 1853 if (size > 1) { 1854 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1855 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1856 } else { 1857 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1858 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1859 } 1860 PetscFunctionReturn(PETSC_SUCCESS); 1861 } 1862