1 #include <petscvec_kokkos.hpp> 2 #include <petscpkg_version.h> 3 #include <petsc/private/sfimpl.h> 4 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 5 #include <../src/mat/impls/aij/mpi/mpiaij.h> 6 #include <KokkosSparse_spadd.hpp> 7 #include <KokkosSparse_spgemm.hpp> 8 9 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 10 { 11 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 12 13 PetscFunctionBegin; 14 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 15 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 16 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 17 */ 18 if (mode == MAT_FINAL_ASSEMBLY) { 19 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 20 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 21 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 22 } 23 PetscFunctionReturn(PETSC_SUCCESS); 24 } 25 26 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 27 { 28 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 29 30 PetscFunctionBegin; 31 // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops 32 if (mat->hash_active) { 33 mat->ops[0] = mpiaij->cops; 34 mat->hash_active = PETSC_FALSE; 35 } 36 37 PetscCall(PetscLayoutSetUp(mat->rmap)); 38 PetscCall(PetscLayoutSetUp(mat->cmap)); 39 #if defined(PETSC_USE_DEBUG) 40 if (d_nnz) { 41 PetscInt i; 42 for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); 43 } 44 if (o_nnz) { 45 PetscInt i; 46 for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]); 47 } 48 #endif 49 #if defined(PETSC_USE_CTABLE) 50 PetscCall(PetscHMapIDestroy(&mpiaij->colmap)); 51 #else 52 PetscCall(PetscFree(mpiaij->colmap)); 53 #endif 54 PetscCall(PetscFree(mpiaij->garray)); 55 PetscCall(VecDestroy(&mpiaij->lvec)); 56 PetscCall(VecScatterDestroy(&mpiaij->Mvctx)); 57 /* Because the B will have been resized we simply destroy it and create a new one each time */ 58 PetscCall(MatDestroy(&mpiaij->B)); 59 60 if (!mpiaij->A) { 61 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A)); 62 PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n)); 63 } 64 if (!mpiaij->B) { 65 PetscMPIInt size; 66 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size)); 67 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B)); 68 PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0)); 69 } 70 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 71 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 72 PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz)); 73 PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz)); 74 mat->preallocated = PETSC_TRUE; 75 PetscFunctionReturn(PETSC_SUCCESS); 76 } 77 78 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 79 { 80 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 81 PetscInt nt; 82 83 PetscFunctionBegin; 84 PetscCall(VecGetLocalSize(xx, &nt)); 85 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 86 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 87 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 88 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 89 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 90 PetscFunctionReturn(PETSC_SUCCESS); 91 } 92 93 static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz) 94 { 95 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 96 PetscInt nt; 97 98 PetscFunctionBegin; 99 PetscCall(VecGetLocalSize(xx, &nt)); 100 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 101 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 102 PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz)); 103 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 104 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz)); 105 PetscFunctionReturn(PETSC_SUCCESS); 106 } 107 108 static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 109 { 110 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 111 PetscInt nt; 112 113 PetscFunctionBegin; 114 PetscCall(VecGetLocalSize(xx, &nt)); 115 PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt); 116 PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec)); 117 PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy)); 118 PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 119 PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE)); 120 PetscFunctionReturn(PETSC_SUCCESS); 121 } 122 123 /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS. 124 A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n). 125 C still uses local column ids. Their corresponding global column ids are returned in glob. 126 */ 127 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 128 { 129 Mat Ad, Ao; 130 const PetscInt *cmap; 131 132 PetscFunctionBegin; 133 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 134 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 135 if (glob) { 136 PetscInt cst, i, dn, on, *gidx; 137 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 138 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 139 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 140 PetscCall(PetscMalloc1(dn + on, &gidx)); 141 for (i = 0; i < dn; i++) gidx[i] = cst + i; 142 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 143 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 144 } 145 PetscFunctionReturn(PETSC_SUCCESS); 146 } 147 148 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 149 struct MatMatStruct { 150 PetscInt n, *garray; // C's garray and its size. 151 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 152 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 153 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 154 PetscIntKokkosView E_NzLeft; 155 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 156 MatScalarKokkosView rootBuf, leafBuf; 157 KokkosCsrMatrix Fd, Fo; // F in split form 158 159 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 160 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 161 KernelHandle kh3; // compute C3 162 KernelHandle kh4; // compute C4 163 164 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 165 PetscInt E_VectorLength; 166 PetscInt E_RowsPerTeam; 167 PetscInt F_TeamSize; 168 PetscInt F_VectorLength; 169 PetscInt F_RowsPerTeam; 170 171 ~MatMatStruct() 172 { 173 PetscFunctionBegin; 174 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 175 PetscFunctionReturnVoid(); 176 } 177 }; 178 179 struct MatMatStruct_AB : public MatMatStruct { 180 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 181 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 182 PetscIntKokkosView rowoffset; 183 }; 184 185 struct MatMatStruct_AtB : public MatMatStruct { 186 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 187 MatColIdxKokkosView Fdjperm; 188 MatColIdxKokkosView Fojmap; 189 MatColIdxKokkosView Fojperm; 190 }; 191 192 struct MatProductData_MPIAIJKokkos { 193 MatMatStruct_AB *mmAB = nullptr; 194 MatMatStruct_AtB *mmAtB = nullptr; 195 PetscBool reusesym = PETSC_FALSE; 196 Mat Z = nullptr; // store Z=AB in computing BtAB 197 198 ~MatProductData_MPIAIJKokkos() 199 { 200 delete mmAB; 201 delete mmAtB; 202 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 203 } 204 }; 205 206 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 207 { 208 PetscFunctionBegin; 209 PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 210 PetscFunctionReturn(PETSC_SUCCESS); 211 } 212 213 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 214 It is similar to MatCreateMPIAIJWithSplitArrays. 215 216 Input Parameters: 217 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 218 . A - the diag matrix using local col ids 219 - B - the offdiag matrix using global col ids 220 221 Output Parameter: 222 . mat - the updated MATMPIAIJKOKKOS matrix 223 */ 224 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 225 { 226 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 227 PetscInt m, n, M, N, Am, An, Bm, Bn; 228 229 PetscFunctionBegin; 230 PetscCall(MatGetSize(mat, &M, &N)); 231 PetscCall(MatGetLocalSize(mat, &m, &n)); 232 PetscCall(MatGetLocalSize(A, &Am, &An)); 233 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 234 235 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 236 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 237 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 238 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 239 mpiaij->A = A; 240 mpiaij->B = B; 241 mpiaij->garray = garray; 242 243 mat->preallocated = PETSC_TRUE; 244 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 245 246 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 247 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 248 /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and 249 also gets mpiaij->B compacted, with its col ids and size reduced 250 */ 251 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 252 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 253 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 254 PetscFunctionReturn(PETSC_SUCCESS); 255 } 256 257 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 258 // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 259 template <class ExecutionSpace> 260 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 261 { 262 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 263 264 PetscFunctionBegin; 265 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 266 267 if (nnz_per_row < 1) nnz_per_row = 1; 268 269 int max_vector_length = teamPolicy.vector_length_max(); 270 271 if (vector_length < 1) { 272 vector_length = 1; 273 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 274 } 275 276 // Determine rows per thread 277 if (rows_per_thread < 1) { 278 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 279 else { 280 if (nnz_per_row < 20 && nnz > 5000000) { 281 rows_per_thread = 256; 282 } else rows_per_thread = 64; 283 } 284 } 285 286 if (team_size < 1) { 287 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 288 team_size = 256 / vector_length; 289 } else { 290 team_size = 1; 291 } 292 } 293 294 rows_per_team = rows_per_thread * team_size; 295 296 if (rows_per_team < 0) { 297 PetscInt nnz_per_team = 4096; 298 PetscInt conc = ExecutionSpace().concurrency(); 299 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 300 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 301 } 302 PetscFunctionReturn(PETSC_SUCCESS); 303 } 304 305 /* 306 Reduce two sets of global indices into local ones 307 308 Input Parameters: 309 + n1 - size of garray1[], the first set 310 . garray1[n1] - a sorted global index array (without duplicates) 311 . m - size of indices[], the second set 312 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 313 314 Output Parameters: 315 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 316 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 317 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 318 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 319 320 Example, say 321 n1 = 5 322 garray1[5] = {1, 4, 7, 8, 10} 323 m = 4 324 indices[4] = {2, 4, 8, 9} 325 326 Combining them together, we have 7 global indices in garray2[] 327 n2 = 7 328 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 329 330 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 331 map[5] = {0, 2, 3, 4, 6} 332 333 On output, indices[] is updated with local indices 334 indices[4] = {1, 2, 4, 5} 335 */ 336 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 337 { 338 PetscHMapI g2l = nullptr; 339 PetscHashIter iter; 340 PetscInt tot, key, val; // total unique global indices. key is global id; val is local id 341 PetscInt n2, *garray2; 342 343 PetscFunctionBegin; 344 tot = 0; 345 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 346 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 347 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 348 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 349 } 350 351 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 352 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 353 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 354 } 355 356 // Pull out (unique) globals in the hash table and put them in garray2[] 357 n2 = tot; 358 PetscCall(PetscMalloc1(n2, &garray2)); 359 tot = 0; 360 PetscHashIterBegin(g2l, iter); 361 while (!PetscHashIterAtEnd(g2l, iter)) { 362 PetscHashIterGetKey(g2l, iter, key); 363 PetscHashIterNext(g2l, iter); 364 garray2[tot++] = key; 365 } 366 367 // Sort garray2[] and then map them to local indices starting from 0 368 PetscCall(PetscSortInt(n2, garray2)); 369 PetscCall(PetscHMapIClear(g2l)); 370 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 371 372 // Rewrite indices[] with local indices 373 for (PetscInt i = 0; i < m; i++) { 374 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 375 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 376 indices[i] = val; 377 } 378 // Record the map that maps garray1[i] to garray2[map[i]] 379 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 380 PetscCall(PetscHMapIDestroy(&g2l)); 381 *n2_ = n2; 382 *garray2_ = garray2; 383 PetscFunctionReturn(PETSC_SUCCESS); 384 } 385 386 /* 387 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 388 389 It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 390 391 Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 392 In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 393 394 Input Parameters: 395 + comm - MPI communicator of E 396 . A - diag block of E, using local column indices 397 . B - off-diag block of E, using local column indices 398 . cstart - (global) start column of Ed 399 . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend) 400 . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size. 401 . ownerSF - the SF specifies ownership (root) of rows in E 402 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 403 - mm - to stash intermediate data structures for reuse 404 405 Output Parameters: 406 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices(). 407 - mm - contains various info, such as garray2[], F (Fd, Fo) etc. 408 409 Notes: 410 When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant. 411 412 */ 413 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 414 { 415 PetscFunctionBegin; 416 if (reuse == MAT_INITIAL_MATRIX) { 417 PetscInt Em = A.numRows(), Fm; 418 PetscInt n1 = B.numCols(); 419 420 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 421 422 // Do the analysis on host 423 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 424 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 425 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 426 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 427 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 428 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 429 430 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 431 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 432 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 433 for (PetscInt i = 0; i < Em; i++) { 434 const PetscInt *first, *last, *it; 435 PetscInt count, step; 436 // std::lower_bound(first,last,cstart), but need to use global column indices 437 first = Bj + Bi[i]; 438 last = Bj + Bi[i + 1]; 439 count = last - first; 440 while (count > 0) { 441 it = first; 442 step = count / 2; 443 it += step; 444 if (garray1[*it] < cstart) { // map local to global 445 first = ++it; 446 count -= step + 1; 447 } else count = step; 448 } 449 E_NzLeft[i] = first - (Bj + Bi[i]); 450 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 451 } 452 453 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 454 const PetscMPIInt *iranks, *ranks; 455 const PetscInt *ioffset, *irootloc, *roffset, *rmine; 456 PetscInt niranks, nranks; 457 MPI_Request *reqs; 458 PetscMPIInt tag; 459 PetscSF reduceSF; 460 PetscInt *sdisp, *rdisp; 461 462 PetscCall(PetscCommGetNewTag(comm, &tag)); 463 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 464 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 465 466 // Find out length of each row I will receive. Even for the same row index, when they are from 467 // different senders, they might have different lengths (and sparsity patterns) 468 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 469 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 470 471 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 472 473 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 474 recvRowLen[0] = 0; // since we will make it in CSR format later 475 recvRowLen++; // advance the pointer now 476 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 477 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 478 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 479 480 // Build the real PetscSF for reducing E rows (buffer to buffer) 481 rdisp[0] = 0; 482 for (PetscInt i = 0; i < niranks; i++) { 483 rdisp[i + 1] = rdisp[i]; 484 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 485 } 486 recvRowLen--; // put it back into csr format 487 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i]; 488 489 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 490 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 491 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 492 493 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 494 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 495 PetscSFNode *iremote; 496 497 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 498 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 499 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 500 501 for (PetscInt i = 0; i < nranks; i++) { 502 PetscInt count = 0; 503 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 504 for (PetscInt j = 0; j < count; j++) { 505 iremote[nleaves + j].rank = ranks[i]; 506 iremote[nleaves + j].index = sdisp[i] + j; 507 } 508 nleaves += count; 509 } 510 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 511 512 PetscCall(PetscSFCreate(comm, &reduceSF)); 513 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 514 515 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 516 PetscInt *sendCol, *recvCol; 517 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 518 for (PetscInt k = 0; k < roffset[nranks]; k++) { 519 PetscInt i = rmine[k]; // row to be copied 520 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 521 PetscInt nzLeft = E_NzLeft[i]; 522 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 523 for (PetscInt j = 0; j < alen + blen; j++) { 524 if (j < nzLeft) { 525 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 526 } else if (j < nzLeft + alen) { 527 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 528 } else { 529 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 530 } 531 } 532 } 533 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 534 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 535 536 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 537 PetscInt *recvRowPerm, *recvColSorted; 538 PetscInt *recvNzPerm, *recvNzPermSorted; 539 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 540 541 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 542 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 543 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 544 545 // i[] array, nz are always easiest to compute 546 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 547 MatRowMapType *Fdi, *Foi; 548 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 549 PetscInt iter; 550 551 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 552 Kokkos::deep_copy(Foi_h, 0); 553 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 554 Foi = Foi_h.data() + 1; 555 iter = 0; 556 while (iter < recvRowCnt) { // iter over received rows 557 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 558 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 559 560 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 561 562 // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted 563 PetscInt nz = 0; // nz (with dups) in the current row 564 PetscInt *jbuf = recvColSorted + FnzDups; 565 PetscInt *pbuf = recvNzPermSorted + FnzDups; 566 PetscInt *jbuf2 = jbuf; // temp pointers 567 PetscInt *pbuf2 = pbuf; 568 for (PetscInt d = 0; d < dupRows; d++) { 569 PetscInt i = recvRowPerm[iter + d]; 570 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 571 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 572 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 573 jbuf2 += len; 574 pbuf2 += len; 575 nz += len; 576 } 577 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 578 579 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 580 PetscInt cur = 0; 581 while (cur < nz) { 582 PetscInt curColIdx = jbuf[cur]; 583 PetscInt dups = 1; 584 585 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 586 if (curColIdx >= cstart && curColIdx < cend) { 587 Fdi[curRowIdx]++; 588 FdnzDups += dups; 589 } else { 590 Foi[curRowIdx]++; 591 FonzDups += dups; 592 } 593 cur += dups; 594 } 595 596 FnzDups += nz; 597 iter += dupRows; // Move to next unique row 598 } 599 600 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 601 Foi = Foi_h.data(); 602 for (PetscInt i = 0; i < Fm; i++) { 603 Fdi[i + 1] += Fdi[i]; 604 Foi[i + 1] += Foi[i]; 605 } 606 Fdnz = Fdi[Fm]; 607 Fonz = Foi[Fm]; 608 PetscCall(PetscFree2(sendCol, recvCol)); 609 610 // Allocate j, jmap, jperm for Fd and Fo 611 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 612 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 613 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 614 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 615 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 616 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 617 618 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 619 Fdjmap[0] = 0; 620 Fojmap[0] = 0; 621 FnzDups = 0; 622 Fdnz = 0; 623 Fonz = 0; 624 iter = 0; // iter over received rows 625 while (iter < recvRowCnt) { 626 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 627 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 628 PetscInt nz = 0; // nz (with dups) in the current row 629 630 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 631 for (PetscInt d = 0; d < dupRows; d++) { 632 PetscInt i = recvRowPerm[iter + d]; 633 nz += recvRowLen[i + 1] - recvRowLen[i]; 634 } 635 636 PetscInt *jbuf = recvColSorted + FnzDups; 637 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 638 PetscInt cur = 0; 639 while (cur < nz) { 640 PetscInt curColIdx = jbuf[cur]; 641 PetscInt dups = 1; 642 643 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 644 if (curColIdx >= cstart && curColIdx < cend) { 645 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 646 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 647 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 648 FdnzDups += dups; 649 Fdnz++; 650 } else { 651 Foj[Fonz] = curColIdx; // in global 652 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 653 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 654 FonzDups += dups; 655 Fonz++; 656 } 657 cur += dups; 658 FnzDups += dups; 659 } 660 iter += dupRows; // Move to next unique row 661 } 662 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 663 PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs)); 664 665 // Combine global column indices in garray1[] and Foj[] 666 PetscInt n2, *garray2; 667 668 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 669 mm->sf = reduceSF; 670 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 671 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 672 mm->garray = garray2; // give ownership, so no free 673 mm->n = n2; 674 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 675 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 676 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 677 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 678 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 679 680 // Output Fd and Fo in KokkosCsrMatrix format 681 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 682 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 683 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 684 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 685 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 686 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 687 688 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 689 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 690 691 // Compute kernel launch parameters in merging E 692 PetscInt teamSize, vectorLength, rowsPerTeam; 693 694 teamSize = vectorLength = rowsPerTeam = -1; 695 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 696 mm->E_TeamSize = teamSize; 697 mm->E_VectorLength = vectorLength; 698 mm->E_RowsPerTeam = rowsPerTeam; 699 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 700 701 // Handy aliases 702 auto &Aa = A.values; 703 auto &Ba = B.values; 704 const auto &Ai = A.graph.row_map; 705 const auto &Bi = B.graph.row_map; 706 const auto &E_NzLeft = mm->E_NzLeft; 707 auto &leafBuf = mm->leafBuf; 708 auto &rootBuf = mm->rootBuf; 709 PetscSF reduceSF = mm->sf; 710 PetscInt Em = A.numRows(); 711 PetscInt teamSize = mm->E_TeamSize; 712 PetscInt vectorLength = mm->E_VectorLength; 713 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 714 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 715 716 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 717 PetscCallCXX(Kokkos::parallel_for( 718 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 719 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 720 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 721 if (i < Em) { 722 PetscInt disp = Ai(i) + Bi(i); 723 PetscInt alen = Ai(i + 1) - Ai(i); 724 PetscInt blen = Bi(i + 1) - Bi(i); 725 PetscInt nzleft = E_NzLeft(i); 726 727 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 728 MatScalar &val = leafBuf(disp + j); 729 if (j < nzleft) { // B left 730 val = Ba(Bi(i) + j); 731 } else if (j < nzleft + alen) { // diag A 732 val = Aa(Ai(i) + j - nzleft); 733 } else { // B right 734 val = Ba(Bi(i) + j - alen); 735 } 736 }); 737 } 738 }); 739 })); 740 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 741 PetscFunctionReturn(PETSC_SUCCESS); 742 } 743 744 // To finish MatMPIAIJKokkosReduce. 745 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 746 { 747 PetscFunctionBegin; 748 auto &leafBuf = mm->leafBuf; 749 auto &rootBuf = mm->rootBuf; 750 auto &Fda = mm->Fd.values; 751 const auto &Fdjmap = mm->Fdjmap; 752 const auto &Fdjperm = mm->Fdjperm; 753 auto Fdnz = mm->Fd.nnz(); 754 auto &Foa = mm->Fo.values; 755 const auto &Fojmap = mm->Fojmap; 756 const auto &Fojperm = mm->Fojperm; 757 auto Fonz = mm->Fo.nnz(); 758 PetscSF reduceSF = mm->sf; 759 760 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 761 762 // Reduce data in rootBuf to Fd and Fo 763 PetscCallCXX(Kokkos::parallel_for( 764 Fdnz, KOKKOS_LAMBDA(const MatRowMapType i) { 765 PetscScalar sum = 0.0; 766 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 767 Fda(i) = sum; 768 })); 769 770 PetscCallCXX(Kokkos::parallel_for( 771 Fonz, KOKKOS_LAMBDA(const MatRowMapType i) { 772 PetscScalar sum = 0.0; 773 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 774 Foa(i) = sum; 775 })); 776 PetscFunctionReturn(PETSC_SUCCESS); 777 } 778 779 /* 780 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 781 782 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 783 device and involves various index mapping. 784 785 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 786 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 787 to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 788 F has the same column layout as E. 789 790 Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 791 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 792 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 793 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 794 column indices in Fo and update Fo with local indices. 795 796 Input Parameters: 797 + E - the MPIAIJKOKKOS matrix 798 . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 799 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 800 - mm - to stash matproduct intermediate data structures 801 802 Output Parameters: 803 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 804 - mm - contains various info, such as garray2[], Fd, Fo, etc. 805 806 Notes: 807 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 808 The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 809 */ 810 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 811 { 812 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 813 Mat A = empi->A, B = empi->B; // diag and off-diag 814 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 815 PetscInt Em = E->rmap->n; // #local rows 816 MPI_Comm comm; 817 818 PetscFunctionBegin; 819 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 820 if (reuse == MAT_INITIAL_MATRIX) { 821 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 822 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 823 const PetscInt *garray1 = empi->garray; // its size is n1 824 PetscInt cstart, cend; 825 PetscSF bcastSF; 826 827 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 828 829 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 830 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 831 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 832 for (PetscInt i = 0; i < Em; i++) { 833 const PetscInt *first, *last, *it; 834 PetscInt count, step; 835 // std::lower_bound(first,last,cstart), but need to use global column indices 836 first = Bj + Bi[i]; 837 last = Bj + Bi[i + 1]; 838 count = last - first; 839 while (count > 0) { 840 it = first; 841 step = count / 2; 842 it += step; 843 if (empi->garray[*it] < cstart) { // map local to global 844 first = ++it; 845 count -= step + 1; 846 } else count = step; 847 } 848 E_NzLeft[i] = first - (Bj + Bi[i]); 849 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 850 } 851 852 // Compute row pointer Fi of F 853 PetscInt *Fi, Fm, Fnz; 854 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 855 PetscCall(PetscMalloc1(Fm + 1, &Fi)); 856 Fi[0] = 0; 857 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 858 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 859 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 860 Fnz = Fi[Fm]; 861 862 // Build the real PetscSF for bcasting E rows (buffer to buffer) 863 const PetscMPIInt *iranks, *ranks; 864 const PetscInt *ioffset, *irootloc, *roffset; 865 PetscInt niranks, nranks, *sdisp, *rdisp; 866 MPI_Request *reqs; 867 PetscMPIInt tag; 868 869 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 870 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 871 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 872 873 sdisp[0] = 0; // send displacement 874 for (PetscInt i = 0; i < niranks; i++) { 875 sdisp[i + 1] = sdisp[i]; 876 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 877 PetscInt r = irootloc[j]; // row to be sent 878 sdisp[i + 1] += E_RowLen[r]; 879 } 880 } 881 882 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 883 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 884 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 885 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 886 887 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 888 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 889 PetscSFNode *iremote; // give ownership to bcastSF 890 PetscCall(PetscMalloc1(nleaves, &iremote)); 891 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 892 PetscInt k = 0; 893 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i] 894 iremote[j].rank = ranks[i]; 895 iremote[j].index = rdisp[i] + k; // their root location 896 k++; 897 } 898 } 899 PetscCall(PetscSFCreate(comm, &bcastSF)); 900 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 901 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 902 903 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 904 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 905 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 906 rowoffset[0] = 0; 907 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 908 909 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 910 PetscInt *jbuf, *Fj; 911 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 912 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 913 PetscInt i = irootloc[k]; // row to be copied 914 PetscInt *buf = &jbuf[rowoffset[k]]; 915 PetscInt nzLeft = E_NzLeft[i]; 916 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 917 for (PetscInt j = 0; j < alen + blen; j++) { 918 if (j < nzLeft) { 919 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 920 } else if (j < nzLeft + alen) { 921 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 922 } else { 923 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 924 } 925 } 926 } 927 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 928 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 929 930 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 931 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 932 MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 933 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 934 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 935 936 Fdi[0] = Foi[0] = 0; 937 for (PetscInt i = 0; i < Fm; i++) { 938 PetscInt *first, *last, *lb1, *lb2; 939 // cut the row into: Left, [cstart, cend), Right 940 first = Fj + Fi[i]; 941 last = Fj + Fi[i + 1]; 942 lb1 = std::lower_bound(first, last, cstart); 943 F_NzLeft[i] = lb1 - first; 944 lb2 = std::lower_bound(first, last, cend); 945 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 946 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 947 } 948 for (PetscInt i = 0; i < Fm; i++) { 949 Fdi[i + 1] += Fdi[i]; 950 Foi[i + 1] += Foi[i]; 951 } 952 953 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 954 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 955 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 956 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 957 958 for (PetscInt i = 0; i < Fm; i++) { 959 PetscInt nzLeft = F_NzLeft[i]; 960 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 961 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 962 gid = Fj[Fi[i] + j]; 963 if (j < nzLeft) { // left, in global 964 Foj[Foi[i] + j] = gid; 965 } else if (j < nzLeft + len) { // diag, in local 966 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 967 } else { // right, in global 968 Foj[Foi[i] + j - len] = gid; 969 } 970 } 971 } 972 PetscCall(PetscFree2(jbuf, Fj)); 973 PetscCall(PetscFree(Fi)); 974 975 // Reduce global indices in Foj[] and garray1[] into local ones 976 PetscInt n2, *garray2; 977 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 978 979 // Record the plans built above, for reuse 980 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety 981 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 982 Kokkos::deep_copy(irootloc_h, tmp); 983 mm->sf = bcastSF; 984 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 985 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 986 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 987 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 988 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 989 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 990 mm->garray = garray2; 991 mm->n = n2; 992 993 // Output Fd and Fo in KokkosCsrMatrix format 994 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 995 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 996 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 997 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 998 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 999 1000 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 1001 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 1002 1003 // Compute kernel launch parameters in merging E or splitting F 1004 PetscInt teamSize, vectorLength, rowsPerTeam; 1005 1006 teamSize = vectorLength = rowsPerTeam = -1; 1007 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 1008 mm->E_TeamSize = teamSize; 1009 mm->E_VectorLength = vectorLength; 1010 mm->E_RowsPerTeam = rowsPerTeam; 1011 1012 teamSize = vectorLength = rowsPerTeam = -1; 1013 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 1014 mm->F_TeamSize = teamSize; 1015 mm->F_VectorLength = vectorLength; 1016 mm->F_RowsPerTeam = rowsPerTeam; 1017 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 1018 1019 // Sync E's value to device 1020 akok->a_dual.sync_device(); 1021 bkok->a_dual.sync_device(); 1022 1023 // Handy aliases 1024 const auto &Aa = akok->a_dual.view_device(); 1025 const auto &Ba = bkok->a_dual.view_device(); 1026 const auto &Ai = akok->i_dual.view_device(); 1027 const auto &Bi = bkok->i_dual.view_device(); 1028 1029 // Fetch the plans 1030 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1031 PetscSF &bcastSF = mm->sf; 1032 MatScalarKokkosView &rootBuf = mm->rootBuf; 1033 MatScalarKokkosView &leafBuf = mm->leafBuf; 1034 PetscIntKokkosView &irootloc = mm->irootloc; 1035 PetscIntKokkosView &rowoffset = mm->rowoffset; 1036 1037 PetscInt teamSize = mm->E_TeamSize; 1038 PetscInt vectorLength = mm->E_VectorLength; 1039 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1040 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1041 1042 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1043 PetscCallCXX(Kokkos::parallel_for( 1044 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1045 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1046 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1047 if (r < irootloc.extent(0)) { 1048 PetscInt i = irootloc(r); // row i of E 1049 PetscInt disp = rowoffset(r); 1050 PetscInt alen = Ai(i + 1) - Ai(i); 1051 PetscInt blen = Bi(i + 1) - Bi(i); 1052 PetscInt nzleft = E_NzLeft(i); 1053 1054 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1055 if (j < nzleft) { // B left 1056 rootBuf(disp + j) = Ba(Bi(i) + j); 1057 } else if (j < nzleft + alen) { // diag A 1058 rootBuf(disp + j) = Aa(Ai(i) + j - nzleft); 1059 } else { // B right 1060 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1061 } 1062 }); 1063 } 1064 }); 1065 })); 1066 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1067 PetscFunctionReturn(PETSC_SUCCESS); 1068 } 1069 1070 // To finish MatMPIAIJKokkosBcast. 1071 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1072 { 1073 PetscFunctionBegin; 1074 const auto &Fd = mm->Fd; 1075 const auto &Fo = mm->Fo; 1076 const auto &Fdi = Fd.graph.row_map; 1077 const auto &Foi = Fo.graph.row_map; 1078 auto &Fda = Fd.values; 1079 auto &Foa = Fo.values; 1080 auto Fm = Fd.numRows(); 1081 1082 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1083 PetscSF &bcastSF = mm->sf; 1084 MatScalarKokkosView &rootBuf = mm->rootBuf; 1085 MatScalarKokkosView &leafBuf = mm->leafBuf; 1086 PetscInt teamSize = mm->F_TeamSize; 1087 PetscInt vectorLength = mm->F_VectorLength; 1088 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1089 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1090 1091 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1092 1093 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1094 PetscCallCXX(Kokkos::parallel_for( 1095 Kokkos::TeamPolicy<>(workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1096 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1097 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1098 if (i < Fm) { 1099 PetscInt nzLeft = F_NzLeft(i); 1100 PetscInt alen = Fdi(i + 1) - Fdi(i); 1101 PetscInt blen = Foi(i + 1) - Foi(i); 1102 PetscInt Fii = Fdi(i) + Foi(i); 1103 1104 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1105 PetscScalar val = leafBuf(Fii + j); 1106 if (j < nzLeft) { // left 1107 Foa(Foi(i) + j) = val; 1108 } else if (j < nzLeft + alen) { // diag 1109 Fda(Fdi(i) + j - nzLeft) = val; 1110 } else { // right 1111 Foa(Foi(i) + j - alen) = val; 1112 } 1113 }); 1114 } 1115 }); 1116 })); 1117 PetscFunctionReturn(PETSC_SUCCESS); 1118 } 1119 1120 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1121 { 1122 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1123 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1124 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1125 PetscInt cstart, cend; 1126 MPI_Comm comm; 1127 1128 PetscFunctionBegin; 1129 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1130 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1131 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1132 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1133 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1134 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1135 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1136 1137 // TODO: add command line options to select spgemm algorithms 1138 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1139 1140 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1141 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1142 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1143 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1144 #endif 1145 #endif 1146 1147 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1148 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1149 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1150 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1151 1152 // Aot * (B's diag + B's off-diag) 1153 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1154 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1155 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1156 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1157 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1158 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1159 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1160 PetscCallCXX(sort_crs_matrix(mm->C3)); 1161 PetscCallCXX(sort_crs_matrix(mm->C4)); 1162 #endif 1163 1164 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1165 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1166 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1167 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1168 1169 // Adt * (B's diag + B's off-diag) 1170 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1171 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1172 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1173 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1174 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1175 PetscCallCXX(sort_crs_matrix(mm->C1)); 1176 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1177 #endif 1178 1179 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1180 1181 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1182 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1183 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1184 PetscCallCXX(Kokkos::parallel_for( 1185 oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1186 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1187 1188 // C = (C1+Fd, C2+Fo) 1189 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1190 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1191 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1192 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1193 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1194 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1195 PetscFunctionReturn(PETSC_SUCCESS); 1196 } 1197 1198 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1199 { 1200 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1201 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1202 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1203 MPI_Comm comm; 1204 1205 PetscFunctionBegin; 1206 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1207 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1208 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1209 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1210 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1211 1212 // Aot * (B's diag + B's off-diag) 1213 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1214 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1215 1216 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1217 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1218 1219 // Adt * (B's diag + B's off-diag) 1220 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1221 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1222 1223 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1224 1225 // C = (C1+Fd, C2+Fo) 1226 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1227 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1228 PetscFunctionReturn(PETSC_SUCCESS); 1229 } 1230 1231 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1232 1233 Input Parameters: 1234 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1235 . A - an MPIAIJKOKKOS matrix 1236 . B - an MPIAIJKOKKOS matrix 1237 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 1238 */ 1239 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1240 { 1241 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1242 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1243 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1244 1245 PetscFunctionBegin; 1246 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1247 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1248 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1249 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1250 1251 // TODO: add command line options to select spgemm algorithms 1252 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1253 1254 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1255 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1256 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1257 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1258 #endif 1259 #endif 1260 1261 mm->kh1.create_spgemm_handle(spgemm_alg); 1262 mm->kh2.create_spgemm_handle(spgemm_alg); 1263 mm->kh3.create_spgemm_handle(spgemm_alg); 1264 mm->kh4.create_spgemm_handle(spgemm_alg); 1265 1266 // Bcast B's rows to form F, and overlap the communication 1267 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1268 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1269 1270 // A's diag * (B's diag + B's off-diag) 1271 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1272 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1273 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1274 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1275 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1276 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1277 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1278 PetscCallCXX(sort_crs_matrix(mm->C1)); 1279 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1280 #endif 1281 1282 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1283 1284 // A's off-diag * (F's diag + F's off-diag) 1285 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1286 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1287 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1288 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1289 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1290 PetscCallCXX(sort_crs_matrix(mm->C3)); 1291 PetscCallCXX(sort_crs_matrix(mm->C4)); 1292 #endif 1293 1294 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1295 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1296 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1297 PetscCallCXX(Kokkos::parallel_for( 1298 oldj.extent(0), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1299 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1300 1301 // C = (Cd, Co) = (C1+C3, C2+C4) 1302 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1303 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1304 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1305 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1306 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1307 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1308 PetscFunctionReturn(PETSC_SUCCESS); 1309 } 1310 1311 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1312 { 1313 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1314 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1315 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1316 1317 PetscFunctionBegin; 1318 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1319 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1320 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1321 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1322 1323 // Bcast B's rows to form F, and overlap the communication 1324 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1325 1326 // A's diag * (B's diag + B's off-diag) 1327 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1328 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1329 1330 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1331 1332 // A's off-diag * (F's diag + F's off-diag) 1333 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1334 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1335 1336 // C = (Cd, Co) = (C1+C3, C2+C4) 1337 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1338 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1339 PetscFunctionReturn(PETSC_SUCCESS); 1340 } 1341 1342 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1343 { 1344 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1345 Mat_Product *product; 1346 MatProductData_MPIAIJKokkos *pdata; 1347 MatProductType ptype; 1348 Mat A, B; 1349 1350 PetscFunctionBegin; 1351 MatCheckProduct(C, 1); // make sure C is a product 1352 product = C->product; 1353 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1354 ptype = product->type; 1355 A = product->A; 1356 B = product->B; 1357 1358 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1359 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1360 // we still do numeric. 1361 if (pdata->reusesym) { // numeric reuses results from symbolic 1362 pdata->reusesym = PETSC_FALSE; 1363 PetscFunctionReturn(PETSC_SUCCESS); 1364 } 1365 1366 if (ptype == MATPRODUCT_AB) { 1367 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1368 } else if (ptype == MATPRODUCT_AtB) { 1369 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1370 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1371 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1372 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1373 } 1374 1375 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1376 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1377 PetscFunctionReturn(PETSC_SUCCESS); 1378 } 1379 1380 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1381 { 1382 Mat A, B; 1383 Mat_Product *product; 1384 MatProductType ptype; 1385 MatProductData_MPIAIJKokkos *pdata; 1386 MatMatStruct *mm = NULL; 1387 PetscInt m, n, M, N; 1388 Mat Cd, Co; 1389 MPI_Comm comm; 1390 1391 PetscFunctionBegin; 1392 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1393 MatCheckProduct(C, 1); 1394 product = C->product; 1395 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1396 ptype = product->type; 1397 A = product->A; 1398 B = product->B; 1399 1400 switch (ptype) { 1401 case MATPRODUCT_AB: 1402 m = A->rmap->n; 1403 n = B->cmap->n; 1404 M = A->rmap->N; 1405 N = B->cmap->N; 1406 break; 1407 case MATPRODUCT_AtB: 1408 m = A->cmap->n; 1409 n = B->cmap->n; 1410 M = A->cmap->N; 1411 N = B->cmap->N; 1412 break; 1413 case MATPRODUCT_PtAP: 1414 m = B->cmap->n; 1415 n = B->cmap->n; 1416 M = B->cmap->N; 1417 N = B->cmap->N; 1418 break; /* BtAB */ 1419 default: 1420 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1421 } 1422 1423 PetscCall(MatSetSizes(C, m, n, M, N)); 1424 PetscCall(PetscLayoutSetUp(C->rmap)); 1425 PetscCall(PetscLayoutSetUp(C->cmap)); 1426 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1427 1428 pdata = new MatProductData_MPIAIJKokkos(); 1429 pdata->reusesym = product->api_user; 1430 1431 if (ptype == MATPRODUCT_AB) { 1432 auto mmAB = new MatMatStruct_AB(); 1433 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1434 mm = pdata->mmAB = mmAB; 1435 } else if (ptype == MATPRODUCT_AtB) { 1436 auto mmAtB = new MatMatStruct_AtB(); 1437 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1438 mm = pdata->mmAtB = mmAtB; 1439 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1440 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1441 1442 auto mmAB = new MatMatStruct_AB(); 1443 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1444 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1445 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1446 pdata->mmAB = mmAB; 1447 1448 m = A->rmap->n; // Z's layout 1449 n = B->cmap->n; 1450 M = A->rmap->N; 1451 N = B->cmap->N; 1452 PetscCall(MatCreate(comm, &Z)); 1453 PetscCall(MatSetSizes(Z, m, n, M, N)); 1454 PetscCall(PetscLayoutSetUp(Z->rmap)); 1455 PetscCall(PetscLayoutSetUp(Z->cmap)); 1456 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1457 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1458 1459 auto mmAtB = new MatMatStruct_AtB(); 1460 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1461 1462 pdata->Z = Z; // give ownership to pdata 1463 mm = pdata->mmAtB = mmAtB; 1464 } 1465 1466 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1467 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1468 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1469 1470 C->product->data = pdata; 1471 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1472 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1473 PetscFunctionReturn(PETSC_SUCCESS); 1474 } 1475 1476 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1477 { 1478 Mat_Product *product = mat->product; 1479 PetscBool match = PETSC_FALSE; 1480 PetscBool usecpu = PETSC_FALSE; 1481 1482 PetscFunctionBegin; 1483 MatCheckProduct(mat, 1); 1484 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1485 if (match) { /* we can always fallback to the CPU if requested */ 1486 switch (product->type) { 1487 case MATPRODUCT_AB: 1488 if (product->api_user) { 1489 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1490 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1491 PetscOptionsEnd(); 1492 } else { 1493 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1494 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1495 PetscOptionsEnd(); 1496 } 1497 break; 1498 case MATPRODUCT_AtB: 1499 if (product->api_user) { 1500 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1501 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1502 PetscOptionsEnd(); 1503 } else { 1504 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1505 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1506 PetscOptionsEnd(); 1507 } 1508 break; 1509 case MATPRODUCT_PtAP: 1510 if (product->api_user) { 1511 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1512 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1513 PetscOptionsEnd(); 1514 } else { 1515 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1516 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1517 PetscOptionsEnd(); 1518 } 1519 break; 1520 default: 1521 break; 1522 } 1523 match = (PetscBool)!usecpu; 1524 } 1525 if (match) { 1526 switch (product->type) { 1527 case MATPRODUCT_AB: 1528 case MATPRODUCT_AtB: 1529 case MATPRODUCT_PtAP: 1530 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1531 break; 1532 default: 1533 break; 1534 } 1535 } 1536 /* fallback to MPIAIJ ops */ 1537 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1538 PetscFunctionReturn(PETSC_SUCCESS); 1539 } 1540 1541 // Mirror of MatCOOStruct_MPIAIJ on device 1542 struct MatCOOStruct_MPIAIJKokkos { 1543 PetscCount n; 1544 PetscSF sf; 1545 PetscCount Annz, Bnnz; 1546 PetscCount Annz2, Bnnz2; 1547 PetscCountKokkosView Ajmap1, Aperm1; 1548 PetscCountKokkosView Bjmap1, Bperm1; 1549 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1550 PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1551 PetscCountKokkosView Cperm1; 1552 MatScalarKokkosView sendbuf, recvbuf; 1553 1554 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) : 1555 n(coo_h->n), 1556 sf(coo_h->sf), 1557 Annz(coo_h->Annz), 1558 Bnnz(coo_h->Bnnz), 1559 Annz2(coo_h->Annz2), 1560 Bnnz2(coo_h->Bnnz2), 1561 Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))), 1562 Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))), 1563 Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))), 1564 Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))), 1565 Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))), 1566 Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))), 1567 Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))), 1568 Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))), 1569 Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))), 1570 Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))), 1571 Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))), 1572 sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))), 1573 recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen))) 1574 { 1575 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1576 } 1577 1578 ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1579 }; 1580 1581 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1582 { 1583 PetscFunctionBegin; 1584 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1585 PetscFunctionReturn(PETSC_SUCCESS); 1586 } 1587 1588 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1589 { 1590 PetscContainer container_h, container_d; 1591 MatCOOStruct_MPIAIJ *coo_h; 1592 MatCOOStruct_MPIAIJKokkos *coo_d; 1593 1594 PetscFunctionBegin; 1595 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1596 mat->preallocated = PETSC_TRUE; 1597 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1598 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1599 PetscCall(MatZeroEntries(mat)); 1600 1601 // Copy the COO struct to device 1602 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1603 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1604 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1605 1606 // Put the COO struct in a container and then attach that to the matrix 1607 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1608 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1609 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1610 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1611 PetscCall(PetscContainerDestroy(&container_d)); 1612 PetscFunctionReturn(PETSC_SUCCESS); 1613 } 1614 1615 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1616 { 1617 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1618 Mat A = mpiaij->A, B = mpiaij->B; 1619 MatScalarKokkosView Aa, Ba; 1620 MatScalarKokkosView v1; 1621 PetscMemType memtype; 1622 PetscContainer container; 1623 MatCOOStruct_MPIAIJKokkos *coo; 1624 1625 PetscFunctionBegin; 1626 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1627 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1628 1629 const auto &n = coo->n; 1630 const auto &Annz = coo->Annz; 1631 const auto &Annz2 = coo->Annz2; 1632 const auto &Bnnz = coo->Bnnz; 1633 const auto &Bnnz2 = coo->Bnnz2; 1634 const auto &vsend = coo->sendbuf; 1635 const auto &v2 = coo->recvbuf; 1636 const auto &Ajmap1 = coo->Ajmap1; 1637 const auto &Ajmap2 = coo->Ajmap2; 1638 const auto &Aimap2 = coo->Aimap2; 1639 const auto &Bjmap1 = coo->Bjmap1; 1640 const auto &Bjmap2 = coo->Bjmap2; 1641 const auto &Bimap2 = coo->Bimap2; 1642 const auto &Aperm1 = coo->Aperm1; 1643 const auto &Aperm2 = coo->Aperm2; 1644 const auto &Bperm1 = coo->Bperm1; 1645 const auto &Bperm2 = coo->Bperm2; 1646 const auto &Cperm1 = coo->Cperm1; 1647 1648 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1649 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1650 v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n)); 1651 } else { 1652 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1653 } 1654 1655 if (imode == INSERT_VALUES) { 1656 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1657 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1658 } else { 1659 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1660 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1661 } 1662 1663 PetscCall(PetscLogGpuTimeBegin()); 1664 /* Pack entries to be sent to remote */ 1665 Kokkos::parallel_for( 1666 vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 1667 1668 /* Send remote entries to their owner and overlap the communication with local computation */ 1669 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1670 /* Add local entries to A and B in one kernel */ 1671 Kokkos::parallel_for( 1672 Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) { 1673 PetscScalar sum = 0.0; 1674 if (i < Annz) { 1675 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1676 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1677 } else { 1678 i -= Annz; 1679 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1680 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1681 } 1682 }); 1683 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1684 1685 /* Add received remote entries to A and B in one kernel */ 1686 Kokkos::parallel_for( 1687 Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) { 1688 if (i < Annz2) { 1689 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1690 } else { 1691 i -= Annz2; 1692 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1693 } 1694 }); 1695 PetscCall(PetscLogGpuTimeEnd()); 1696 1697 if (imode == INSERT_VALUES) { 1698 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */ 1699 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1700 } else { 1701 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1702 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1703 } 1704 PetscFunctionReturn(PETSC_SUCCESS); 1705 } 1706 1707 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1708 { 1709 PetscFunctionBegin; 1710 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1711 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1712 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1713 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1714 PetscCall(MatDestroy_MPIAIJ(A)); 1715 PetscFunctionReturn(PETSC_SUCCESS); 1716 } 1717 1718 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1719 { 1720 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1721 PetscBool congruent; 1722 1723 PetscFunctionBegin; 1724 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1725 if (congruent) { // square matrix and the diagonals are solely in the diag block 1726 PetscCall(MatShift(mpiaij->A, a)); 1727 } else { // too hard, use the general version 1728 PetscCall(MatShift_Basic(A, a)); 1729 } 1730 PetscFunctionReturn(PETSC_SUCCESS); 1731 } 1732 1733 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1734 { 1735 PetscFunctionBegin; 1736 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1737 B->ops->mult = MatMult_MPIAIJKokkos; 1738 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1739 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1740 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1741 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1742 B->ops->shift = MatShift_MPIAIJKokkos; 1743 1744 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1745 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1746 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1747 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1748 PetscFunctionReturn(PETSC_SUCCESS); 1749 } 1750 1751 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1752 { 1753 Mat B; 1754 Mat_MPIAIJ *a; 1755 1756 PetscFunctionBegin; 1757 if (reuse == MAT_INITIAL_MATRIX) { 1758 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1759 } else if (reuse == MAT_REUSE_MATRIX) { 1760 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1761 } 1762 B = *newmat; 1763 1764 B->boundtocpu = PETSC_FALSE; 1765 PetscCall(PetscFree(B->defaultvectype)); 1766 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1767 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1768 1769 a = static_cast<Mat_MPIAIJ *>(A->data); 1770 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1771 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1772 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1773 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1774 PetscFunctionReturn(PETSC_SUCCESS); 1775 } 1776 1777 /*MC 1778 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1779 1780 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1781 1782 Options Database Key: 1783 . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1784 1785 Level: beginner 1786 1787 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1788 M*/ 1789 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1790 { 1791 PetscFunctionBegin; 1792 PetscCall(PetscKokkosInitializeCheck()); 1793 PetscCall(MatCreate_MPIAIJ(A)); 1794 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1795 PetscFunctionReturn(PETSC_SUCCESS); 1796 } 1797 1798 /*@C 1799 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1800 (the default parallel PETSc format). This matrix will ultimately pushed down 1801 to Kokkos for calculations. 1802 1803 Collective 1804 1805 Input Parameters: 1806 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1807 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1808 This value should be the same as the local size used in creating the 1809 y vector for the matrix-vector product y = Ax. 1810 . n - This value should be the same as the local size used in creating the 1811 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1812 calculated if N is given) For square matrices n is almost always `m`. 1813 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1814 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1815 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1816 (same value is used for all local rows) 1817 . d_nnz - array containing the number of nonzeros in the various rows of the 1818 DIAGONAL portion of the local submatrix (possibly different for each row) 1819 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1820 The size of this array is equal to the number of local rows, i.e `m`. 1821 For matrices you plan to factor you must leave room for the diagonal entry and 1822 put in the entry even if it is zero. 1823 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1824 submatrix (same value is used for all local rows). 1825 - o_nnz - array containing the number of nonzeros in the various rows of the 1826 OFF-DIAGONAL portion of the local submatrix (possibly different for 1827 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1828 structure. The size of this array is equal to the number 1829 of local rows, i.e `m`. 1830 1831 Output Parameter: 1832 . A - the matrix 1833 1834 Level: intermediate 1835 1836 Notes: 1837 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1838 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1839 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1840 1841 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1842 storage. That is, the stored row and column indices can begin at 1843 either one (as in Fortran) or zero. 1844 1845 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1846 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1847 @*/ 1848 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1849 { 1850 PetscMPIInt size; 1851 1852 PetscFunctionBegin; 1853 PetscCall(MatCreate(comm, A)); 1854 PetscCall(MatSetSizes(*A, m, n, M, N)); 1855 PetscCallMPI(MPI_Comm_size(comm, &size)); 1856 if (size > 1) { 1857 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1858 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1859 } else { 1860 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1861 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1862 } 1863 PetscFunctionReturn(PETSC_SUCCESS); 1864 } 1865