1 #include <petsc_kokkos.hpp> 2 #include <petscvec_kokkos.hpp> 3 #include <petscpkg_version.h> 4 #include <petsc/private/sfimpl.h> 5 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp> 6 #include <../src/mat/impls/aij/mpi/mpiaij.h> 7 #include <KokkosSparse_spadd.hpp> 8 #include <KokkosSparse_spgemm.hpp> 9 10 static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode) 11 { 12 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data; 13 14 PetscFunctionBegin; 15 PetscCall(MatAssemblyEnd_MPIAIJ(A, mode)); 16 /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS. 17 Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases. 18 */ 19 if (mode == MAT_FINAL_ASSEMBLY) { 20 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 21 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 22 PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); 23 } 24 PetscFunctionReturn(PETSC_SUCCESS); 25 } 26 27 static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) 28 { 29 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 30 31 PetscFunctionBegin; 32 // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops 33 if (mat->hash_active) { 34 mat->ops[0] = mpiaij->cops; 35 mat->hash_active = PETSC_FALSE; 36 } 37 38 PetscCall(PetscLayoutSetUp(mat->rmap)); 39 PetscCall(PetscLayoutSetUp(mat->cmap)); 40 #if defined(PETSC_USE_DEBUG) 41 if (d_nnz) { 42 PetscInt i; 43 for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]); 44 } 45 if (o_nnz) { 46 PetscInt i; 47 for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" 
PetscInt_FMT, i, o_nnz[i]); 48 } 49 #endif 50 #if defined(PETSC_USE_CTABLE) 51 PetscCall(PetscHMapIDestroy(&mpiaij->colmap)); 52 #else 53 PetscCall(PetscFree(mpiaij->colmap)); 54 #endif 55 PetscCall(PetscFree(mpiaij->garray)); 56 PetscCall(VecDestroy(&mpiaij->lvec)); 57 PetscCall(VecScatterDestroy(&mpiaij->Mvctx)); 58 /* Because the B will have been resized we simply destroy it and create a new one each time */ 59 PetscCall(MatDestroy(&mpiaij->B)); 60 61 if (!mpiaij->A) { 62 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A)); 63 PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n)); 64 } 65 if (!mpiaij->B) { 66 PetscMPIInt size; 67 PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size)); 68 PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B)); 69 PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0)); 70 } 71 PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS)); 72 PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS)); 73 PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz)); 74 PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz)); 75 mat->preallocated = PETSC_TRUE; 76 PetscFunctionReturn(PETSC_SUCCESS); 77 } 78 79 static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy) 80 { 81 Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data; 82 PetscInt nt; 83 84 PetscFunctionBegin; 85 PetscCall(VecGetLocalSize(xx, &nt)); 86 PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt); 87 PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 88 PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy)); 89 PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD)); 90 PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy)); 91 PetscFunctionReturn(PETSC_SUCCESS); 92 } 93 94 
/*
  MatMultAdd_MPIAIJKokkos - Compute zz = yy + mat * xx

  The diagonal-block multiply-add runs while the off-process entries of xx are being scattered into lvec;
  the off-diagonal block is then applied to lvec and accumulated into zz.
*/
static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
{
  Mat_MPIAIJ *aij = (Mat_MPIAIJ *)mat->data;
  Mat         diag, offdiag;
  Vec         ghost;
  PetscInt    xlen;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(xx, &xlen));
  PetscCheck(xlen == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, xlen);

  diag    = aij->A;
  offdiag = aij->B;
  ghost   = aij->lvec;

  PetscCall(VecScatterBegin(aij->Mvctx, xx, ghost, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall(diag->ops->multadd(diag, xx, yy, zz)); // zz = yy + A*xx, overlapped with the scatter
  PetscCall(VecScatterEnd(aij->Mvctx, xx, ghost, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall(offdiag->ops->multadd(offdiag, ghost, zz, zz)); // zz += B*lvec
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMultTranspose_MPIAIJKokkos - Compute yy = mat^T * xx

  B^T*xx is computed into lvec first, A^T*xx into yy, and lvec is then added into yy with a reverse scatter.
*/
static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
{
  Mat_MPIAIJ *aij = (Mat_MPIAIJ *)mat->data;
  PetscInt    xlen;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(xx, &xlen));
  PetscCheck(xlen == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, xlen);

  PetscCall(aij->B->ops->multtranspose(aij->B, xx, aij->lvec)); // lvec = B^T xx
  PetscCall(aij->A->ops->multtranspose(aij->A, xx, yy));        // yy   = A^T xx
  PetscCall(VecScatterBegin(aij->Mvctx, aij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
  PetscCall(VecScatterEnd(aij->Mvctx, aij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
   A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
   C still uses local column ids. Their corresponding global column ids are returned in glob.
127 */ 128 static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C) 129 { 130 Mat Ad, Ao; 131 const PetscInt *cmap; 132 133 PetscFunctionBegin; 134 PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap)); 135 PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C)); 136 if (glob) { 137 PetscInt cst, i, dn, on, *gidx; 138 PetscCall(MatGetLocalSize(Ad, NULL, &dn)); 139 PetscCall(MatGetLocalSize(Ao, NULL, &on)); 140 PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL)); 141 PetscCall(PetscMalloc1(dn + on, &gidx)); 142 for (i = 0; i < dn; i++) gidx[i] = cst + i; 143 for (i = 0; i < on; i++) gidx[i + dn] = cmap[i]; 144 PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); 145 } 146 PetscFunctionReturn(PETSC_SUCCESS); 147 } 148 149 /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */ 150 struct MatMatStruct { 151 PetscInt n, *garray; // C's garray and its size. 152 KokkosCsrMatrix Cd, Co; // C is in split form matrices (all in local column indcies) 153 KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products 154 KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size) 155 PetscIntKokkosView E_NzLeft; 156 PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F 157 MatScalarKokkosView rootBuf, leafBuf; 158 KokkosCsrMatrix Fd, Fo; // F in split form 159 160 KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd 161 KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo 162 KernelHandle kh3; // compute C3 163 KernelHandle kh4; // compute C4 164 165 PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F 166 PetscInt E_VectorLength; 167 PetscInt E_RowsPerTeam; 168 PetscInt F_TeamSize; 169 PetscInt F_VectorLength; 170 PetscInt F_RowsPerTeam; 171 172 ~MatMatStruct() 173 { 174 PetscFunctionBegin; 175 PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf)); 176 PetscFunctionReturnVoid(); 
177 } 178 }; 179 180 struct MatMatStruct_AB : public MatMatStruct { 181 PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo 182 PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf 183 PetscIntKokkosView rowoffset; 184 }; 185 186 struct MatMatStruct_AtB : public MatMatStruct { 187 MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo 188 MatColIdxKokkosView Fdjperm; 189 MatColIdxKokkosView Fojmap; 190 MatColIdxKokkosView Fojperm; 191 }; 192 193 struct MatProductData_MPIAIJKokkos { 194 MatMatStruct_AB *mmAB = nullptr; 195 MatMatStruct_AtB *mmAtB = nullptr; 196 PetscBool reusesym = PETSC_FALSE; 197 Mat Z = nullptr; // store Z=AB in computing BtAB 198 199 ~MatProductData_MPIAIJKokkos() 200 { 201 delete mmAB; 202 delete mmAtB; 203 PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z)); 204 } 205 }; 206 207 static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data) 208 { 209 PetscFunctionBegin; 210 PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data)); 211 PetscFunctionReturn(PETSC_SUCCESS); 212 } 213 214 /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix. 215 It is similar to MatCreateMPIAIJWithSplitArrays. 216 217 Input Parameters: 218 + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set 219 . A - the diag matrix using local col ids 220 - B - the offdiag matrix using global col ids 221 222 Output Parameter: 223 . 
mat - the updated MATMPIAIJKOKKOS matrix 224 */ 225 static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray) 226 { 227 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 228 PetscInt m, n, M, N, Am, An, Bm, Bn; 229 230 PetscFunctionBegin; 231 PetscCall(MatGetSize(mat, &M, &N)); 232 PetscCall(MatGetLocalSize(mat, &m, &n)); 233 PetscCall(MatGetLocalSize(A, &Am, &An)); 234 PetscCall(MatGetLocalSize(B, &Bm, &Bn)); 235 236 PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match"); 237 PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match"); 238 // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match"); 239 PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty"); 240 mpiaij->A = A; 241 mpiaij->B = B; 242 mpiaij->garray = garray; 243 244 mat->preallocated = PETSC_TRUE; 245 mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */ 246 247 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); 248 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 249 /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and 250 also gets mpiaij->B compacted, with its col ids and size reduced 251 */ 252 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 253 PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE)); 254 PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE)); 255 PetscFunctionReturn(PETSC_SUCCESS); 256 } 257 258 // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or 259 // split csr matrices. 
The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block) 260 template <class ExecutionSpace> 261 static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team) 262 { 263 Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO); 264 265 PetscFunctionBegin; 266 PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices 267 268 if (nnz_per_row < 1) nnz_per_row = 1; 269 270 int max_vector_length = teamPolicy.vector_length_max(); 271 272 if (vector_length < 1) { 273 vector_length = 1; 274 while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2; 275 } 276 277 // Determine rows per thread 278 if (rows_per_thread < 1) { 279 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1; 280 else { 281 if (nnz_per_row < 20 && nnz > 5000000) { 282 rows_per_thread = 256; 283 } else rows_per_thread = 64; 284 } 285 } 286 287 if (team_size < 1) { 288 if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) { 289 team_size = 256 / vector_length; 290 } else { 291 team_size = 1; 292 } 293 } 294 295 rows_per_team = rows_per_thread * team_size; 296 297 if (rows_per_team < 0) { 298 PetscInt nnz_per_team = 4096; 299 PetscInt conc = ExecutionSpace().concurrency(); 300 while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2; 301 rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row; 302 } 303 PetscFunctionReturn(PETSC_SUCCESS); 304 } 305 306 /* 307 Reduce two sets of global indices into local ones 308 309 Input Parameters: 310 + n1 - size of garray1[], the first set 311 . garray1[n1] - a sorted global index array (without duplicates) 312 . 
m - size of indices[], the second set 313 - indices[m] - a unsorted global index array (might have duplicates), which will be updated on output into local ones 314 315 Output Parameters: 316 + n2 - size of garray2[], the merged set, which combines garray1[] and indices[] 317 . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it. 318 . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]] 319 - indices[m] - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]] 320 321 Example, say 322 n1 = 5 323 garray1[5] = {1, 4, 7, 8, 10} 324 m = 4 325 indices[4] = {2, 4, 8, 9} 326 327 Combining them together, we have 7 global indices in garray2[] 328 n2 = 7 329 garray2[7] = {1, 2, 4, 7, 8, 9, 10} 330 331 And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)" 332 map[5] = {0, 2, 3, 4, 6} 333 334 On output, indices[] is updated with local indices 335 indices[4] = {1, 2, 4, 5} 336 */ 337 static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map) 338 { 339 PetscHMapI g2l = nullptr; 340 PetscHashIter iter; 341 PetscInt tot, key, val; // total unique global indices. 
key is global id; val is local id 342 PetscInt n2, *garray2; 343 344 PetscFunctionBegin; 345 tot = 0; 346 PetscCall(PetscHMapICreateWithSize(n1, &g2l)); 347 for (PetscInt i = 0; i < m; i++) { // insert those in indices[] 348 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1 349 if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet 350 } 351 352 for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[] 353 PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val)); 354 if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++)); 355 } 356 357 // Pull out (unique) globals in the hash table and put them in garray2[] 358 n2 = tot; 359 PetscCall(PetscMalloc1(n2, &garray2)); 360 tot = 0; 361 PetscHashIterBegin(g2l, iter); 362 while (!PetscHashIterAtEnd(g2l, iter)) { 363 PetscHashIterGetKey(g2l, iter, key); 364 PetscHashIterNext(g2l, iter); 365 garray2[tot++] = key; 366 } 367 368 // Sort garray2[] and then map them to local indices starting from 0 369 PetscCall(PetscSortInt(n2, garray2)); 370 PetscCall(PetscHMapIClear(g2l)); 371 for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id 372 373 // Rewrite indices[] with local indices 374 for (PetscInt i = 0; i < m; i++) { 375 PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); 376 PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index"); 377 indices[i] = val; 378 } 379 // Record the map that maps garray1[i] to garray2[map[i]] 380 for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i])); 381 PetscCall(PetscHMapIDestroy(&g2l)); 382 *n2_ = n2; 383 *garray2_ = garray2; 384 PetscFunctionReturn(PETSC_SUCCESS); 385 } 386 387 /* 388 MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm) 389 
390 It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 391 392 Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 393 In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 394 395 Input Parameters: 396 + comm - MPI communicator of E 397 . A - diag block of E, using local column indices 398 . B - off-diag block of E, using local column indices 399 . cstart - (global) start column of Ed 400 . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend) 401 . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size. 402 . ownerSF - the SF specifies ownership (root) of rows in E 403 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 404 - mm - to stash intermediate data structures for reuse 405 406 Output Parameters: 407 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices(). 408 - mm - contains various info, such as garray2[], F (Fd, Fo) etc. 409 410 Notes: 411 When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant. 
412 413 */ 414 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 415 { 416 PetscFunctionBegin; 417 if (reuse == MAT_INITIAL_MATRIX) { 418 PetscInt Em = A.numRows(), Fm; 419 PetscInt n1 = B.numCols(); 420 421 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 422 423 // Do the analysis on host 424 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 425 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 426 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 427 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 428 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 429 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 430 431 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 432 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 433 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 434 for (PetscInt i = 0; i < Em; i++) { 435 const PetscInt *first, *last, *it; 436 PetscInt count, step; 437 // std::lower_bound(first,last,cstart), but need to use global column indices 438 first = Bj + Bi[i]; 439 last = Bj + Bi[i + 1]; 440 count = last - first; 441 while (count > 0) { 442 it = first; 443 step = count / 2; 444 it += step; 445 if (garray1[*it] < cstart) { // map local to global 446 first = ++it; 447 count -= step + 1; 448 } else count = step; 449 } 450 E_NzLeft[i] = first - (Bj + Bi[i]); 451 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 452 } 453 454 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 455 const PetscMPIInt *iranks, *ranks; 456 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 457 PetscInt niranks, nranks; 458 MPI_Request *reqs; 459 PetscMPIInt tag; 460 PetscSF reduceSF; 461 PetscInt *sdisp, *rdisp; 462 463 PetscCall(PetscCommGetNewTag(comm, &tag)); 464 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 465 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 466 467 // Find out length of each row I will receive. Even for the same row index, when they are from 468 // different senders, they might have different lengths (and sparsity patterns) 469 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 470 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 471 472 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 473 474 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 475 recvRowLen[0] = 0; // since we will make it in CSR format later 476 recvRowLen++; // advance the pointer now 477 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 478 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 479 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 480 481 // Build the real PetscSF for reducing E rows (buffer to buffer) 482 rdisp[0] = 0; 483 for (PetscInt i = 0; i < niranks; i++) { 484 rdisp[i + 1] = rdisp[i]; 485 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 486 } 487 recvRowLen--; // put it back into csr format 488 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 
1] += recvRowLen[i]; 489 490 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 491 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 492 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 493 494 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 495 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 496 PetscSFNode *iremote; 497 498 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 499 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 500 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 501 502 for (PetscInt i = 0; i < nranks; i++) { 503 PetscInt count = 0; 504 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 505 for (PetscInt j = 0; j < count; j++) { 506 iremote[nleaves + j].rank = ranks[i]; 507 iremote[nleaves + j].index = sdisp[i] + j; 508 } 509 nleaves += count; 510 } 511 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 512 513 PetscCall(PetscSFCreate(comm, &reduceSF)); 514 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 515 516 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 517 PetscInt *sendCol, *recvCol; 518 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 519 for (PetscInt k = 0; k < roffset[nranks]; k++) { 520 PetscInt i = rmine[k]; // row to be copied 521 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 522 PetscInt nzLeft = E_NzLeft[i]; 523 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 524 for (PetscInt j = 0; j < alen + blen; j++) { 525 if (j < nzLeft) { 526 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 527 } else if (j < nzLeft + alen) { 528 buf[j] = Aj[Ai[i] 
+ j - nzLeft] + cstart; // diag A, also in global 529 } else { 530 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 531 } 532 } 533 } 534 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 535 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 536 537 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 538 PetscInt *recvRowPerm, *recvColSorted; 539 PetscInt *recvNzPerm, *recvNzPermSorted; 540 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 541 542 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 543 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 544 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 545 546 // i[] array, nz are always easiest to compute 547 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 548 MatRowMapType *Fdi, *Foi; 549 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 550 PetscInt iter; 551 552 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 553 Kokkos::deep_copy(Foi_h, 0); 554 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 555 Foi = Foi_h.data() + 1; 556 iter = 0; 557 while (iter < recvRowCnt) { // iter over received rows 558 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 559 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 560 561 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 562 563 // Copy column indices 
(and their permutation) of these rows into recvColSorted & recvNzPermSorted 564 PetscInt nz = 0; // nz (with dups) in the current row 565 PetscInt *jbuf = recvColSorted + FnzDups; 566 PetscInt *pbuf = recvNzPermSorted + FnzDups; 567 PetscInt *jbuf2 = jbuf; // temp pointers 568 PetscInt *pbuf2 = pbuf; 569 for (PetscInt d = 0; d < dupRows; d++) { 570 PetscInt i = recvRowPerm[iter + d]; 571 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 572 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 573 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 574 jbuf2 += len; 575 pbuf2 += len; 576 nz += len; 577 } 578 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 579 580 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 581 PetscInt cur = 0; 582 while (cur < nz) { 583 PetscInt curColIdx = jbuf[cur]; 584 PetscInt dups = 1; 585 586 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 587 if (curColIdx >= cstart && curColIdx < cend) { 588 Fdi[curRowIdx]++; 589 FdnzDups += dups; 590 } else { 591 Foi[curRowIdx]++; 592 FonzDups += dups; 593 } 594 cur += dups; 595 } 596 597 FnzDups += nz; 598 iter += dupRows; // Move to next unique row 599 } 600 601 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 602 Foi = Foi_h.data(); 603 for (PetscInt i = 0; i < Fm; i++) { 604 Fdi[i + 1] += Fdi[i]; 605 Foi[i + 1] += Foi[i]; 606 } 607 Fdnz = Fdi[Fm]; 608 Fonz = Foi[Fm]; 609 PetscCall(PetscFree2(sendCol, recvCol)); 610 611 // Allocate j, jmap, jperm for Fd and Fo 612 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 613 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 614 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 615 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 616 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 617 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 618 619 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 620 Fdjmap[0] = 0; 621 Fojmap[0] = 0; 622 FnzDups = 0; 623 Fdnz = 0; 624 Fonz = 0; 625 iter = 0; // iter over received rows 626 while (iter < recvRowCnt) { 627 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 628 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 629 PetscInt nz = 0; // nz (with dups) in the current row 630 631 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 632 for (PetscInt d = 0; d < dupRows; d++) { 633 PetscInt i = recvRowPerm[iter + d]; 634 nz += recvRowLen[i + 1] - recvRowLen[i]; 635 } 636 637 PetscInt *jbuf = recvColSorted + FnzDups; 638 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 639 PetscInt cur = 0; 640 while (cur < nz) { 641 PetscInt curColIdx = jbuf[cur]; 642 PetscInt dups = 1; 643 644 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 645 if (curColIdx >= cstart && curColIdx < cend) { 646 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 647 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 648 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 649 FdnzDups += dups; 650 Fdnz++; 651 } else { 652 Foj[Fonz] = curColIdx; // in global 653 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 654 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 655 FonzDups += dups; 656 Fonz++; 657 } 658 cur += dups; 659 FnzDups += dups; 660 } 661 iter += dupRows; // Move to next unique row 662 } 663 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 664 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 665 666 // Combine global column indices in garray1[] and Foj[] 667 PetscInt n2, *garray2; 668 669 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 670 mm->sf = reduceSF; 671 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 672 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 673 mm->garray = garray2; // give ownership, so no free 674 mm->n = n2; 675 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 676 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 677 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 678 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 679 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 680 681 // Output Fd and Fo in KokkosCsrMatrix format 682 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 683 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 684 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 685 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 686 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 687 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 688 689 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 690 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 691 692 // Compute kernel launch parameters in merging E 693 PetscInt teamSize, vectorLength, rowsPerTeam; 694 695 teamSize = vectorLength = rowsPerTeam = -1; 696 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 697 mm->E_TeamSize = teamSize; 698 mm->E_VectorLength = 
vectorLength; 699 mm->E_RowsPerTeam = rowsPerTeam; 700 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 701 702 // Handy aliases 703 auto &Aa = A.values; 704 auto &Ba = B.values; 705 const auto &Ai = A.graph.row_map; 706 const auto &Bi = B.graph.row_map; 707 const auto &E_NzLeft = mm->E_NzLeft; 708 auto &leafBuf = mm->leafBuf; 709 auto &rootBuf = mm->rootBuf; 710 PetscSF reduceSF = mm->sf; 711 PetscInt Em = A.numRows(); 712 PetscInt teamSize = mm->E_TeamSize; 713 PetscInt vectorLength = mm->E_VectorLength; 714 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 715 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 716 717 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 718 PetscCallCXX(Kokkos::parallel_for( 719 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 720 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 721 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 722 if (i < Em) { 723 PetscInt disp = Ai(i) + Bi(i); 724 PetscInt alen = Ai(i + 1) - Ai(i); 725 PetscInt blen = Bi(i + 1) - Bi(i); 726 PetscInt nzleft = E_NzLeft(i); 727 728 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 729 MatScalar &val = leafBuf(disp + j); 730 if (j < nzleft) { // B left 731 val = Ba(Bi(i) + j); 732 } else if (j < nzleft + alen) { // diag A 733 val = Aa(Ai(i) + j - nzleft); 734 } else { // B right 735 val = Ba(Bi(i) + j - alen); 736 } 737 }); 738 } 739 }); 740 })); 741 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 742 PetscFunctionReturn(PETSC_SUCCESS); 743 } 744 745 // To finish MatMPIAIJKokkosReduce. 
746 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 747 { 748 auto &leafBuf = mm->leafBuf; 749 auto &rootBuf = mm->rootBuf; 750 auto &Fda = mm->Fd.values; 751 const auto &Fdjmap = mm->Fdjmap; 752 const auto &Fdjperm = mm->Fdjperm; 753 auto Fdnz = mm->Fd.nnz(); 754 auto &Foa = mm->Fo.values; 755 const auto &Fojmap = mm->Fojmap; 756 const auto &Fojperm = mm->Fojperm; 757 auto Fonz = mm->Fo.nnz(); 758 PetscSF reduceSF = mm->sf; 759 760 PetscFunctionBegin; 761 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 762 763 // Reduce data in rootBuf to Fd and Fo 764 PetscCallCXX(Kokkos::parallel_for( 765 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 766 PetscScalar sum = 0.0; 767 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 768 Fda(i) = sum; 769 })); 770 771 PetscCallCXX(Kokkos::parallel_for( 772 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 773 PetscScalar sum = 0.0; 774 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 775 Foa(i) = sum; 776 })); 777 PetscFunctionReturn(PETSC_SUCCESS); 778 } 779 780 /* 781 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 782 783 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 784 device and involves various index mapping. 785 786 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 
787 Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i), it means we need to bcast the i-th row of E on rank k 788 to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E. 789 F has the same column layout as E. 790 791 Conceptually F has global column indices. In this routine, we spit F into diagonal Fd and off-diagonal Fo. 792 Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices. 793 Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global 794 column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with 795 column indices in Fo and update Fo with local indices. 796 797 Input Parameters: 798 + E - the MPIAIJKOKKOS matrix 799 . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX) 800 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 801 - mm - to stash matproduct intermediate data structures 802 803 Output Parameters: 804 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices. 805 - mm - contains various info, such as garray2[], Fd, Fo, etc. 806 807 Notes: 808 When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant. 809 The routine is provide in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication opportunities. 
810 */ 811 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 812 { 813 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 814 Mat A = empi->A, B = empi->B; // diag and off-diag 815 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 816 PetscInt Em = E->rmap->n; // #local rows 817 MPI_Comm comm; 818 819 PetscFunctionBegin; 820 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 821 if (reuse == MAT_INITIAL_MATRIX) { 822 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 823 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 824 const PetscInt *garray1 = empi->garray; // its size is n1 825 PetscInt cstart, cend; 826 PetscSF bcastSF; 827 828 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 829 830 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 831 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 832 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 833 for (PetscInt i = 0; i < Em; i++) { 834 const PetscInt *first, *last, *it; 835 PetscInt count, step; 836 // std::lower_bound(first,last,cstart), but need to use global column indices 837 first = Bj + Bi[i]; 838 last = Bj + Bi[i + 1]; 839 count = last - first; 840 while (count > 0) { 841 it = first; 842 step = count / 2; 843 it += step; 844 if (empi->garray[*it] < cstart) { // map local to global 845 first = ++it; 846 count -= step + 1; 847 } else count = step; 848 } 849 E_NzLeft[i] = first - (Bj + Bi[i]); 850 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 851 } 852 853 // Compute row pointer Fi of F 854 PetscInt *Fi, Fm, Fnz; 855 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 856 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 857 Fi[0] = 0; 858 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 859 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 860 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 861 Fnz = Fi[Fm]; 862 863 // Build the real PetscSF for bcasting E rows (buffer to buffer) 864 const PetscMPIInt *iranks, *ranks; 865 const PetscInt *ioffset, *irootloc, *roffset; 866 PetscInt niranks, nranks, *sdisp, *rdisp; 867 MPI_Request *reqs; 868 PetscMPIInt tag; 869 870 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 871 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 872 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 873 874 sdisp[0] = 0; // send displacement 875 for (PetscInt i = 0; i < niranks; i++) { 876 sdisp[i + 1] = sdisp[i]; 877 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 878 PetscInt r = irootloc[j]; // row to be sent 879 sdisp[i + 1] += E_RowLen[r]; 880 } 881 } 882 883 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 884 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 885 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 886 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 887 888 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 889 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 890 PetscSFNode *iremote; // give ownership to bcastSF 891 PetscCall(PetscMalloc1(nleaves, &iremote)); 892 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 893 PetscInt k = 0; 894 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 895 iremote[j].rank = ranks[i]; 896 iremote[j].index = rdisp[i] + k; // their root location 897 k++; 898 } 899 } 900 PetscCall(PetscSFCreate(comm, &bcastSF)); 901 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 902 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 903 904 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 905 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 906 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 907 rowoffset[0] = 0; 908 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 909 910 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 911 PetscInt *jbuf, *Fj; 912 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 913 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 914 PetscInt i = irootloc[k]; // row to be copied 915 PetscInt *buf = &jbuf[rowoffset[k]]; 916 PetscInt nzLeft = E_NzLeft[i]; 917 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 918 for (PetscInt j = 0; j < alen + blen; j++) { 919 if (j < nzLeft) { 920 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 921 } else if (j < nzLeft + alen) { 922 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 923 } else { 924 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 925 } 926 } 927 } 928 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 929 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 930 931 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 932 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 933 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 934 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 935 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 936 937 Fdi[0] = Foi[0] = 0; 938 for (PetscInt i = 0; i < Fm; i++) { 939 PetscInt *first, *last, *lb1, *lb2; 940 // cut the row into: Left, [cstart, cend), Right 941 first = Fj + Fi[i]; 942 last = Fj + Fi[i + 1]; 943 lb1 = std::lower_bound(first, last, cstart); 944 F_NzLeft[i] = lb1 - first; 945 lb2 = std::lower_bound(first, last, cend); 946 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 947 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 948 } 949 for (PetscInt i = 0; i < Fm; i++) { 950 Fdi[i + 1] += Fdi[i]; 951 Foi[i + 1] += Foi[i]; 952 } 953 954 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 955 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 956 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 957 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 958 959 for (PetscInt i = 0; i < Fm; i++) { 960 PetscInt nzLeft = F_NzLeft[i]; 961 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 962 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 963 gid = Fj[Fi[i] + j]; 964 if (j < nzLeft) { // left, in global 965 Foj[Foi[i] + j] = gid; 966 } else if (j < nzLeft + len) { // diag, in local 967 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 968 } else { // right, in global 969 Foj[Foi[i] + j - len] = gid; 970 } 971 } 972 } 973 PetscCall(PetscFree2(jbuf, Fj)); 974 PetscCall(PetscFree(Fi)); 975 976 // Reduce global indices in Foj[] and garray1[] into local ones 977 PetscInt n2, *garray2; 978 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 979 980 // Record the plans built above, for reuse 981 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 982 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 983 Kokkos::deep_copy(irootloc_h, tmp); 984 mm->sf = bcastSF; 985 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 986 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 987 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 988 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 989 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 990 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 991 mm->garray = garray2; 992 mm->n = n2; 993 994 // Output Fd and Fo in KokkosCsrMatrix format 995 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 996 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 997 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 998 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 999 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 1000 1001 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 1002 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 1003 1004 // Compute kernel launch parameters in merging E or splitting F 1005 PetscInt teamSize, vectorLength, rowsPerTeam; 1006 1007 teamSize = vectorLength = rowsPerTeam = -1; 1008 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 1009 mm->E_TeamSize = teamSize; 1010 mm->E_VectorLength = vectorLength; 1011 mm->E_RowsPerTeam = rowsPerTeam; 1012 1013 teamSize = vectorLength = rowsPerTeam = -1; 1014 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, 
-1, teamSize, vectorLength, rowsPerTeam)); 1015 mm->F_TeamSize = teamSize; 1016 mm->F_VectorLength = vectorLength; 1017 mm->F_RowsPerTeam = rowsPerTeam; 1018 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 1019 1020 // Sync E's value to device 1021 akok->a_dual.sync_device(); 1022 bkok->a_dual.sync_device(); 1023 1024 // Handy aliases 1025 const auto &Aa = akok->a_dual.view_device(); 1026 const auto &Ba = bkok->a_dual.view_device(); 1027 const auto &Ai = akok->i_dual.view_device(); 1028 const auto &Bi = bkok->i_dual.view_device(); 1029 1030 // Fetch the plans 1031 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1032 PetscSF &bcastSF = mm->sf; 1033 MatScalarKokkosView &rootBuf = mm->rootBuf; 1034 MatScalarKokkosView &leafBuf = mm->leafBuf; 1035 PetscIntKokkosView &irootloc = mm->irootloc; 1036 PetscIntKokkosView &rowoffset = mm->rowoffset; 1037 1038 PetscInt teamSize = mm->E_TeamSize; 1039 PetscInt vectorLength = mm->E_VectorLength; 1040 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1041 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1042 1043 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1044 PetscCallCXX(Kokkos::parallel_for( 1045 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1046 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1047 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1048 if (r < irootloc.extent(0)) { 1049 PetscInt i = irootloc(r); // row i of E 1050 PetscInt disp = rowoffset(r); 1051 PetscInt alen = Ai(i + 1) - Ai(i); 1052 PetscInt blen = Bi(i + 1) - Bi(i); 1053 PetscInt nzleft = E_NzLeft(i); 1054 1055 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1056 if (j < nzleft) { // B left 1057 rootBuf(disp + j) = Ba(Bi(i) + j); 1058 } else if (j < nzleft + alen) { // diag A 1059 
rootBuf(disp + j) = Aa(Ai(i) + j - nzleft); 1060 } else { // B right 1061 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1062 } 1063 }); 1064 } 1065 }); 1066 })); 1067 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1068 PetscFunctionReturn(PETSC_SUCCESS); 1069 } 1070 1071 // To finish MatMPIAIJKokkosBcast. 1072 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1073 { 1074 PetscFunctionBegin; 1075 const auto &Fd = mm->Fd; 1076 const auto &Fo = mm->Fo; 1077 const auto &Fdi = Fd.graph.row_map; 1078 const auto &Foi = Fo.graph.row_map; 1079 auto &Fda = Fd.values; 1080 auto &Foa = Fo.values; 1081 auto Fm = Fd.numRows(); 1082 1083 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1084 PetscSF &bcastSF = mm->sf; 1085 MatScalarKokkosView &rootBuf = mm->rootBuf; 1086 MatScalarKokkosView &leafBuf = mm->leafBuf; 1087 PetscInt teamSize = mm->F_TeamSize; 1088 PetscInt vectorLength = mm->F_VectorLength; 1089 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1090 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1091 1092 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1093 1094 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1095 PetscCallCXX(Kokkos::parallel_for( 1096 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1097 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1098 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1099 if (i < Fm) { 1100 PetscInt nzLeft = F_NzLeft(i); 1101 PetscInt alen = Fdi(i + 1) - Fdi(i); 1102 PetscInt blen = Foi(i + 1) - Foi(i); 1103 PetscInt Fii = Fdi(i) + Foi(i); 1104 1105 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1106 PetscScalar val = leafBuf(Fii 
+ j); 1107 if (j < nzLeft) { // left 1108 Foa(Foi(i) + j) = val; 1109 } else if (j < nzLeft + alen) { // diag 1110 Fda(Fdi(i) + j - nzLeft) = val; 1111 } else { // right 1112 Foa(Foi(i) + j - alen) = val; 1113 } 1114 }); 1115 } 1116 }); 1117 })); 1118 PetscFunctionReturn(PETSC_SUCCESS); 1119 } 1120 1121 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1122 { 1123 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1124 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1125 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1126 PetscInt cstart, cend; 1127 MPI_Comm comm; 1128 1129 PetscFunctionBegin; 1130 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1131 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1132 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1133 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1134 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1135 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1136 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1137 1138 // TODO: add command line options to select spgemm algorithms 1139 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1140 1141 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1142 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1143 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1144 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1145 #endif 1146 #endif 1147 1148 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1149 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1150 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1151 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1152 1153 // Aot * (B's diag + B's off-diag) 1154 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1155 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1156 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1157 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1158 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1159 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1160 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1161 1162 PetscCallCXX(sort_crs_matrix(mm->C3)); 1163 PetscCallCXX(sort_crs_matrix(mm->C4)); 1164 #endif 1165 1166 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1167 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1168 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1169 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1170 1171 // Adt * (B's diag + B's off-diag) 1172 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1173 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1174 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1175 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1176 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1177 PetscCallCXX(sort_crs_matrix(mm->C1)); 1178 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1179 #endif 1180 1181 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1182 1183 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1184 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1185 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1186 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1187 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1188 1189 // C = (C1+Fd, C2+Fo) 1190 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1191 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1192 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1193 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1194 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1195 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1196 PetscFunctionReturn(PETSC_SUCCESS); 1197 } 1198 1199 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1200 { 1201 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1202 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1203 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1204 MPI_Comm comm; 1205 1206 PetscFunctionBegin; 1207 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1208 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1209 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1210 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1211 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1212 1213 // Aot * (B's diag + B's off-diag) 1214 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1215 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1216 1217 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1218 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1219 1220 // Adt * (B's diag + B's off-diag) 1221 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1222 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1223 1224 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1225 1226 // C = (C1+Fd, C2+Fo) 1227 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1228 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1229 PetscFunctionReturn(PETSC_SUCCESS); 1230 } 1231 1232 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1233 1234 Input Parameters: 1235 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1236 . A - an MPIAIJKOKKOS matrix 1237 . B - an MPIAIJKOKKOS matrix 1238 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1239 */ 1240 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1241 { 1242 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1243 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1244 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1245 1246 PetscFunctionBegin; 1247 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1248 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1249 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1250 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1251 1252 // TODO: add command line options to select spgemm algorithms 1253 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1254 1255 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1256 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1257 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1258 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1259 #endif 1260 #endif 1261 1262 mm->kh1.create_spgemm_handle(spgemm_alg); 1263 mm->kh2.create_spgemm_handle(spgemm_alg); 1264 mm->kh3.create_spgemm_handle(spgemm_alg); 1265 mm->kh4.create_spgemm_handle(spgemm_alg); 1266 1267 // Bcast B's rows to form F, and overlap the communication 1268 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1269 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1270 1271 // A's diag * (B's diag + B's off-diag) 1272 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1273 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1274 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1275 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1276 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1277 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1278 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1279 PetscCallCXX(sort_crs_matrix(mm->C1)); 1280 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1281 #endif 1282 1283 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1284 1285 // A's off-diag * (F's diag + F's off-diag) 1286 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1287 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1288 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1289 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1290 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1291 PetscCallCXX(sort_crs_matrix(mm->C3)); 1292 PetscCallCXX(sort_crs_matrix(mm->C4)); 1293 #endif 1294 1295 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1296 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1297 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1298 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1299 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1300 1301 // C = (Cd, Co) = (C1+C3, C2+C4) 1302 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1303 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1304 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1305 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1306 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1307 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1308 PetscFunctionReturn(PETSC_SUCCESS); 1309 } 1310 1311 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1312 { 1313 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1314 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1315 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1316 1317 PetscFunctionBegin; 1318 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1319 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1320 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1321 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1322 1323 // Bcast B's rows to form F, and overlap the communication 1324 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1325 1326 // A's diag * (B's diag + B's off-diag) 1327 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1328 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1329 1330 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1331 1332 // A's off-diag * (F's diag + F's off-diag) 1333 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1334 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1335 1336 // C = (Cd, Co) = (C1+C3, C2+C4) 1337 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1338 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1339 PetscFunctionReturn(PETSC_SUCCESS); 1340 } 1341 1342 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1343 { 1344 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1345 Mat_Product *product; 1346 MatProductData_MPIAIJKokkos *pdata; 1347 
MatProductType ptype; 1348 Mat A, B; 1349 1350 PetscFunctionBegin; 1351 MatCheckProduct(C, 1); // make sure C is a product 1352 product = C->product; 1353 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1354 ptype = product->type; 1355 A = product->A; 1356 B = product->B; 1357 1358 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1359 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1360 // we still do numeric. 1361 if (pdata->reusesym) { // numeric reuses results from symbolic 1362 pdata->reusesym = PETSC_FALSE; 1363 PetscFunctionReturn(PETSC_SUCCESS); 1364 } 1365 1366 if (ptype == MATPRODUCT_AB) { 1367 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1368 } else if (ptype == MATPRODUCT_AtB) { 1369 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1370 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1371 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1372 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1373 } 1374 1375 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1376 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1377 PetscFunctionReturn(PETSC_SUCCESS); 1378 } 1379 1380 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1381 { 1382 Mat A, B; 1383 Mat_Product *product; 1384 MatProductType ptype; 1385 MatProductData_MPIAIJKokkos *pdata; 1386 MatMatStruct *mm = NULL; 1387 PetscInt m, n, M, N; 1388 Mat Cd, Co; 1389 MPI_Comm comm; 1390 1391 PetscFunctionBegin; 1392 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1393 MatCheckProduct(C, 1); 1394 product = C->product; 1395 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1396 ptype = product->type; 1397 A = product->A; 1398 B = product->B; 
1399 1400 switch (ptype) { 1401 case MATPRODUCT_AB: 1402 m = A->rmap->n; 1403 n = B->cmap->n; 1404 M = A->rmap->N; 1405 N = B->cmap->N; 1406 break; 1407 case MATPRODUCT_AtB: 1408 m = A->cmap->n; 1409 n = B->cmap->n; 1410 M = A->cmap->N; 1411 N = B->cmap->N; 1412 break; 1413 case MATPRODUCT_PtAP: 1414 m = B->cmap->n; 1415 n = B->cmap->n; 1416 M = B->cmap->N; 1417 N = B->cmap->N; 1418 break; /* BtAB */ 1419 default: 1420 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1421 } 1422 1423 PetscCall(MatSetSizes(C, m, n, M, N)); 1424 PetscCall(PetscLayoutSetUp(C->rmap)); 1425 PetscCall(PetscLayoutSetUp(C->cmap)); 1426 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1427 1428 pdata = new MatProductData_MPIAIJKokkos(); 1429 pdata->reusesym = product->api_user; 1430 1431 if (ptype == MATPRODUCT_AB) { 1432 auto mmAB = new MatMatStruct_AB(); 1433 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1434 mm = pdata->mmAB = mmAB; 1435 } else if (ptype == MATPRODUCT_AtB) { 1436 auto mmAtB = new MatMatStruct_AtB(); 1437 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1438 mm = pdata->mmAtB = mmAtB; 1439 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1440 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1441 1442 auto mmAB = new MatMatStruct_AB(); 1443 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1444 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1445 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1446 pdata->mmAB = mmAB; 1447 1448 m = A->rmap->n; // Z's layout 1449 n = B->cmap->n; 1450 M = A->rmap->N; 1451 N = B->cmap->N; 1452 PetscCall(MatCreate(comm, &Z)); 1453 PetscCall(MatSetSizes(Z, m, n, M, N)); 1454 PetscCall(PetscLayoutSetUp(Z->rmap)); 1455 PetscCall(PetscLayoutSetUp(Z->cmap)); 1456 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1457 
PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1458 1459 auto mmAtB = new MatMatStruct_AtB(); 1460 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1461 1462 pdata->Z = Z; // give ownership to pdata 1463 mm = pdata->mmAtB = mmAtB; 1464 } 1465 1466 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1467 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1468 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1469 1470 C->product->data = pdata; 1471 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1472 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1473 PetscFunctionReturn(PETSC_SUCCESS); 1474 } 1475 1476 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1477 { 1478 Mat_Product *product = mat->product; 1479 PetscBool match = PETSC_FALSE; 1480 PetscBool usecpu = PETSC_FALSE; 1481 1482 PetscFunctionBegin; 1483 MatCheckProduct(mat, 1); 1484 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1485 if (match) { /* we can always fallback to the CPU if requested */ 1486 switch (product->type) { 1487 case MATPRODUCT_AB: 1488 if (product->api_user) { 1489 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1490 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1491 PetscOptionsEnd(); 1492 } else { 1493 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1494 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1495 PetscOptionsEnd(); 1496 } 1497 break; 1498 case MATPRODUCT_AtB: 1499 if 
(product->api_user) { 1500 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1501 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1502 PetscOptionsEnd(); 1503 } else { 1504 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1505 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1506 PetscOptionsEnd(); 1507 } 1508 break; 1509 case MATPRODUCT_PtAP: 1510 if (product->api_user) { 1511 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1512 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1513 PetscOptionsEnd(); 1514 } else { 1515 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1516 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1517 PetscOptionsEnd(); 1518 } 1519 break; 1520 default: 1521 break; 1522 } 1523 match = (PetscBool)!usecpu; 1524 } 1525 if (match) { 1526 switch (product->type) { 1527 case MATPRODUCT_AB: 1528 case MATPRODUCT_AtB: 1529 case MATPRODUCT_PtAP: 1530 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1531 break; 1532 default: 1533 break; 1534 } 1535 } 1536 /* fallback to MPIAIJ ops */ 1537 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1538 PetscFunctionReturn(PETSC_SUCCESS); 1539 } 1540 1541 // Mirror of MatCOOStruct_MPIAIJ on device 1542 struct MatCOOStruct_MPIAIJKokkos { 1543 PetscCount n; 1544 PetscSF sf; 1545 PetscCount Annz, Bnnz; 1546 PetscCount Annz2, Bnnz2; 1547 PetscCountKokkosView Ajmap1, Aperm1; 1548 PetscCountKokkosView Bjmap1, Bperm1; 1549 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1550 
PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1551 PetscCountKokkosView Cperm1; 1552 MatScalarKokkosView sendbuf, recvbuf; 1553 1554 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) : 1555 n(coo_h->n), 1556 sf(coo_h->sf), 1557 Annz(coo_h->Annz), 1558 Bnnz(coo_h->Bnnz), 1559 Annz2(coo_h->Annz2), 1560 Bnnz2(coo_h->Bnnz2), 1561 Ajmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1))), 1562 Aperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1))), 1563 Bjmap1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1))), 1564 Bperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1))), 1565 Aimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2))), 1566 Ajmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1))), 1567 Aperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2))), 1568 Bimap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2))), 1569 Bjmap2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1))), 1570 Bperm2(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2))), 1571 Cperm1(Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen))), 1572 sendbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen))), 1573 recvbuf(Kokkos::create_mirror_view(Kokkos::WithoutInitializing, DefaultMemorySpace(), 
MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen))) 1574 { 1575 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1576 } 1577 1578 ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1579 }; 1580 1581 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1582 { 1583 PetscFunctionBegin; 1584 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1585 PetscFunctionReturn(PETSC_SUCCESS); 1586 } 1587 1588 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1589 { 1590 PetscContainer container_h, container_d; 1591 MatCOOStruct_MPIAIJ *coo_h; 1592 MatCOOStruct_MPIAIJKokkos *coo_d; 1593 1594 PetscFunctionBegin; 1595 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1596 mat->preallocated = PETSC_TRUE; 1597 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1598 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1599 PetscCall(MatZeroEntries(mat)); 1600 1601 // Copy the COO struct to device 1602 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1603 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1604 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1605 1606 // Put the COO struct in a container and then attach that to the matrix 1607 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1608 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1609 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1610 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1611 PetscCall(PetscContainerDestroy(&container_d)); 1612 PetscFunctionReturn(PETSC_SUCCESS); 1613 } 1614 1615 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1616 { 1617 Mat_MPIAIJ *mpiaij = 
static_cast<Mat_MPIAIJ *>(mat->data); 1618 Mat A = mpiaij->A, B = mpiaij->B; 1619 MatScalarKokkosView Aa, Ba; 1620 MatScalarKokkosView v1; 1621 PetscMemType memtype; 1622 PetscContainer container; 1623 MatCOOStruct_MPIAIJKokkos *coo; 1624 1625 PetscFunctionBegin; 1626 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1627 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1628 1629 const auto &n = coo->n; 1630 const auto &Annz = coo->Annz; 1631 const auto &Annz2 = coo->Annz2; 1632 const auto &Bnnz = coo->Bnnz; 1633 const auto &Bnnz2 = coo->Bnnz2; 1634 const auto &vsend = coo->sendbuf; 1635 const auto &v2 = coo->recvbuf; 1636 const auto &Ajmap1 = coo->Ajmap1; 1637 const auto &Ajmap2 = coo->Ajmap2; 1638 const auto &Aimap2 = coo->Aimap2; 1639 const auto &Bjmap1 = coo->Bjmap1; 1640 const auto &Bjmap2 = coo->Bjmap2; 1641 const auto &Bimap2 = coo->Bimap2; 1642 const auto &Aperm1 = coo->Aperm1; 1643 const auto &Aperm2 = coo->Aperm2; 1644 const auto &Bperm1 = coo->Bperm1; 1645 const auto &Bperm2 = coo->Bperm2; 1646 const auto &Cperm1 = coo->Cperm1; 1647 1648 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1649 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1650 v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, n)); 1651 } else { 1652 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1653 } 1654 1655 if (imode == INSERT_VALUES) { 1656 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1657 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1658 } else { 1659 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1660 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1661 } 1662 1663 PetscCall(PetscLogGpuTimeBegin()); 1664 /* Pack entries to be sent to remote */ 1665 
Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); }); 1666 1667 /* Send remote entries to their owner and overlap the communication with local computation */ 1668 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1669 /* Add local entries to A and B in one kernel */ 1670 Kokkos::parallel_for( 1671 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1672 PetscScalar sum = 0.0; 1673 if (i < Annz) { 1674 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1675 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1676 } else { 1677 i -= Annz; 1678 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1679 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1680 } 1681 }); 1682 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1683 1684 /* Add received remote entries to A and B in one kernel */ 1685 Kokkos::parallel_for( 1686 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1687 if (i < Annz2) { 1688 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1689 } else { 1690 i -= Annz2; 1691 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1692 } 1693 }); 1694 PetscCall(PetscLogGpuTimeEnd()); 1695 1696 if (imode == INSERT_VALUES) { 1697 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1698 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1699 } else { 1700 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1701 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1702 } 1703 PetscFunctionReturn(PETSC_SUCCESS); 1704 } 1705 1706 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1707 { 1708 PetscFunctionBegin; 1709 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1710 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1711 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1712 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1713 PetscCall(MatDestroy_MPIAIJ(A)); 1714 PetscFunctionReturn(PETSC_SUCCESS); 1715 } 1716 1717 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1718 { 1719 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1720 PetscBool congruent; 1721 1722 PetscFunctionBegin; 1723 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1724 if (congruent) { // square matrix and the diagonals are solely in the diag block 1725 PetscCall(MatShift(mpiaij->A, a)); 1726 } else { // too hard, use the general version 1727 PetscCall(MatShift_Basic(A, a)); 1728 } 1729 PetscFunctionReturn(PETSC_SUCCESS); 1730 } 1731 1732 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1733 { 1734 PetscFunctionBegin; 1735 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1736 B->ops->mult = MatMult_MPIAIJKokkos; 1737 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1738 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1739 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1740 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1741 B->ops->shift = MatShift_MPIAIJKokkos; 1742 1743 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1744 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1745 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1746 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1747 PetscFunctionReturn(PETSC_SUCCESS); 1748 } 1749 1750 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1751 { 1752 Mat B; 1753 Mat_MPIAIJ *a; 1754 1755 PetscFunctionBegin; 1756 if (reuse == MAT_INITIAL_MATRIX) { 1757 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1758 } else if (reuse == MAT_REUSE_MATRIX) { 1759 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1760 } 1761 B = *newmat; 1762 1763 B->boundtocpu = PETSC_FALSE; 1764 PetscCall(PetscFree(B->defaultvectype)); 1765 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1766 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1767 1768 a = static_cast<Mat_MPIAIJ *>(A->data); 1769 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1770 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1771 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1772 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1773 PetscFunctionReturn(PETSC_SUCCESS); 1774 } 1775 1776 /*MC 1777 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1778 1779 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1780 1781 Options Database Key: 1782 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1783 1784 Level: beginner 1785 1786 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1787 M*/ 1788 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1789 { 1790 PetscFunctionBegin; 1791 PetscCall(PetscKokkosInitializeCheck()); 1792 PetscCall(MatCreate_MPIAIJ(A)); 1793 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1794 PetscFunctionReturn(PETSC_SUCCESS); 1795 } 1796 1797 /*@C 1798 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1799 (the default parallel PETSc format). This matrix will ultimately pushed down 1800 to Kokkos for calculations. 1801 1802 Collective 1803 1804 Input Parameters: 1805 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1806 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1807 This value should be the same as the local size used in creating the 1808 y vector for the matrix-vector product y = Ax. 1809 . n - This value should be the same as the local size used in creating the 1810 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1811 calculated if N is given) For square matrices n is almost always `m`. 1812 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1813 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1814 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1815 (same value is used for all local rows) 1816 . d_nnz - array containing the number of nonzeros in the various rows of the 1817 DIAGONAL portion of the local submatrix (possibly different for each row) 1818 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1819 The size of this array is equal to the number of local rows, i.e `m`. 
1820 For matrices you plan to factor you must leave room for the diagonal entry and 1821 put in the entry even if it is zero. 1822 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1823 submatrix (same value is used for all local rows). 1824 - o_nnz - array containing the number of nonzeros in the various rows of the 1825 OFF-DIAGONAL portion of the local submatrix (possibly different for 1826 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1827 structure. The size of this array is equal to the number 1828 of local rows, i.e `m`. 1829 1830 Output Parameter: 1831 . A - the matrix 1832 1833 Level: intermediate 1834 1835 Notes: 1836 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1837 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1838 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1839 1840 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1841 storage. That is, the stored row and column indices can begin at 1842 either one (as in Fortran) or zero. 
1843 1844 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1845 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1846 @*/ 1847 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1848 { 1849 PetscMPIInt size; 1850 1851 PetscFunctionBegin; 1852 PetscCall(MatCreate(comm, A)); 1853 PetscCall(MatSetSizes(*A, m, n, M, N)); 1854 PetscCallMPI(MPI_Comm_size(comm, &size)); 1855 if (size > 1) { 1856 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1857 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1858 } else { 1859 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1860 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1861 } 1862 PetscFunctionReturn(PETSC_SUCCESS); 1863 } 1864