#include <petsc_kokkos.hpp>
#include <petscvec_kokkos.hpp>
#include <petscpkg_version.h>
#include <petsc/private/sfimpl.h>
#include <petsc/private/kokkosimpl.hpp>
#include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
#include <../src/mat/impls/aij/mpi/mpiaij.h>
#include <KokkosSparse_spadd.hpp>
#include <KokkosSparse_spgemm.hpp>

/*
  MatAssemblyEnd_MPIAIJKokkos - Run MatAssemblyEnd_MPIAIJ(), then on final assembly (re)set the diagonal block A and
  off-diagonal block B to MATSEQAIJKOKKOS and the ghost vector lvec to VECSEQKOKKOS.
*/
static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
{
  Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;

  PetscFunctionBegin;
  PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
  /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
     Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
   */
  if (mode == MAT_FINAL_ASSEMBLY) {
    PetscScalarKokkosView v;

    PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
    PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
    PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
    PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
    PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMPIAIJSetPreallocation_MPIAIJKokkos - Preallocate the diagonal (A) and off-diagonal (B) blocks of a MATMPIAIJKOKKOS matrix.

  Discards any active hash-table insertion mode, destroys the existing colmap, garray, lvec, Mvctx and B, creates A
  (only if it does not exist yet) and a new B, sets both to MATSEQAIJKOKKOS, and preallocates them with
  MatSeqAIJSetPreallocation(). On a single process B is created with 0 columns.
*/
static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
{
  Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;

  PetscFunctionBegin;
  // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
  if (mat->hash_active) {
    mat->ops[0]      = mpiaij->cops;
    mat->hash_active = PETSC_FALSE;
  }

  PetscCall(PetscLayoutSetUp(mat->rmap));
  PetscCall(PetscLayoutSetUp(mat->cmap));
#if defined(PETSC_USE_DEBUG)
  // Debug builds only: reject negative per-row preallocation counts
  if (d_nnz) {
    PetscInt i;
    for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
  }
  if (o_nnz) {
    PetscInt i;
    for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
  }
#endif
#if defined(PETSC_USE_CTABLE)
  PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
#else
  PetscCall(PetscFree(mpiaij->colmap));
#endif
  PetscCall(PetscFree(mpiaij->garray));
  PetscCall(VecDestroy(&mpiaij->lvec));
  PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
  /* Because the B will have been resized we simply destroy it and create a new one each time */
  PetscCall(MatDestroy(&mpiaij->B));

  if (!mpiaij->A) {
    PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
    PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
  }
  if (!mpiaij->B) { // always true here, since B was destroyed above
    PetscMPIInt size;
    PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
    PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
    // B spans all global columns when there are several processes, and has no columns on a single process
    PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
  }
  PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
  PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
  PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
  PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
  mat->preallocated = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* yy = mat * xx. The scatter of ghost entries of xx into lvec overlaps with the diagonal-block product A*xx; then yy += B*lvec */
static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
{
  Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
  PetscInt    nt;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(xx, &nt));
  PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
  PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
  PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* zz = yy + mat * xx, with the same scatter/compute overlap as MatMult_MPIAIJKokkos() */
static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
{
  Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
  PetscInt    nt;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(xx, &nt));
  PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
  PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
  PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* yy = mat^T * xx. lvec = B^T*xx and yy = A^T*xx are computed locally, then lvec is added into yy with a reverse scatter */
static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
{
  Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
  PetscInt    nt;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(xx, &nt));
  PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
  PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
  PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
  PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
  PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
   A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
   C still uses local column ids. Their corresponding global column ids are returned in glob (if glob is not NULL):
   the first A->cmap->n entries are the owned column range starting at mat's first owned column, followed by B's garray.
 */
static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
{
  Mat             Ad, Ao;
  const PetscInt *cmap;

  PetscFunctionBegin;
  PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
  PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
  if (glob) {
    PetscInt cst, i, dn, on, *gidx;
    PetscCall(MatGetLocalSize(Ad, NULL, &dn));
    PetscCall(MatGetLocalSize(Ao, NULL, &on));
    PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
    PetscCall(PetscMalloc1(dn + on, &gidx));
    for (i = 0; i < dn; i++) gidx[i] = cst + i;
    for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
    PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob)); // IS takes ownership of gidx
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
struct MatMatStruct {
  PetscInt        n, *garray;     // C's garray and its size.
  KokkosCsrMatrix Cd, Co;         // C is in split form matrices (all in local column indices)
  KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products
  KokkosCsrMatrix C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
  PetscIntKokkosView E_NzLeft;
  PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
  MatScalarKokkosView rootBuf, leafBuf;
  KokkosCsrMatrix     Fd, Fo; // F in split form

  KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
  KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
  KernelHandle kh3; // compute C3
  KernelHandle kh4; // compute C4

  PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
  PetscInt E_VectorLength;
  PetscInt E_RowsPerTeam;
  PetscInt F_TeamSize;
  PetscInt F_VectorLength;
  PetscInt F_RowsPerTeam;

  // Only sf is destroyed here; garray is NOT freed (NOTE(review): presumably its ownership is handed to C elsewhere -- confirm)
  ~MatMatStruct()
  {
    PetscFunctionBegin;
    PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
    PetscFunctionReturnVoid();
  }
};

struct MatMatStruct_AB : public MatMatStruct {
  PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
  PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
  PetscIntKokkosView rowoffset;
};

struct MatMatStruct_AtB : public MatMatStruct {
  MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
  MatColIdxKokkosView Fdjperm;
  MatColIdxKokkosView Fojmap;
  MatColIdxKokkosView Fojperm;
};

// Product data attached to C; the destructor deletes mmAB, mmAtB and destroys Z
struct MatProductData_MPIAIJKokkos {
  MatMatStruct_AB  *mmAB     = nullptr;
  MatMatStruct_AtB *mmAtB    = nullptr;
  PetscBool         reusesym = PETSC_FALSE;
  Mat               Z        = nullptr; // store Z=AB in computing BtAB

  ~MatProductData_MPIAIJKokkos()
  {
    delete mmAB;
    delete mmAtB;
    PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
  }
};

/* Destroy callback for a MatProductData_MPIAIJKokkos passed as an opaque pointer */
static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
{
  PetscFunctionBegin;
  PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
   It is similar to MatCreateMPIAIJWithSplitArrays.

  Input Parameters:
+  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
.  A     - the diag matrix using local col ids
.  B     - the offdiag matrix using global col ids
-  garray - stored as mpiaij->garray without copying, so mat takes it over (MatMPIAIJSetPreallocation_MPIAIJKokkos() frees mpiaij->garray).
            NOTE(review): since a garray is supplied, confirm whether B is really expected in global or in local column ids.

  Output Parameter:
.  mat - the updated MATMPIAIJKOKKOS matrix
*/
static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
{
  Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
  PetscInt    m, n, M, N, Am, An, Bm, Bn;

  PetscFunctionBegin;
  PetscCall(MatGetSize(mat, &M, &N));
  PetscCall(MatGetLocalSize(mat, &m, &n));
  PetscCall(MatGetLocalSize(A, &Am, &An));
  PetscCall(MatGetLocalSize(B, &Bm, &Bn));

  PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
  PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
  // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
  PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
  mpiaij->A      = A;
  mpiaij->B      = B;
  mpiaij->garray = garray;

  mat->preallocated     = PETSC_TRUE;
  mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */

  PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
  PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
  /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
     also gets mpiaij->B compacted, with its col ids and size reduced
  */
  PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
  PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
  PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we used to merge or
// split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
//
// team_size and vector_length are in/out: values >= 1 are kept as given, otherwise they are computed.
// rows_per_thread < 1 means "choose automatically". rows_per_team is always overwritten.
template <class ExecutionSpace>
static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
{
  Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);

  PetscFunctionBegin;
  PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might meet empty matrices

  if (nnz_per_row < 1) nnz_per_row = 1;

  int max_vector_length = teamPolicy.vector_length_max();

  // Double vector_length (up to the policy max) while rows average more than 6 nonzeros per vector lane
  if (vector_length < 1) {
    vector_length = 1;
    while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
  }

  // Determine rows per thread
  if (rows_per_thread < 1) {
    if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
    else {
      if (nnz_per_row < 20 && nnz > 5000000) {
        rows_per_thread = 256;
      } else rows_per_thread = 64;
    }
  }

  if (team_size < 1) {
    if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
      team_size = 256 / vector_length;
    } else {
      team_size = 1;
    }
  }

  rows_per_team = rows_per_thread * team_size;

  // Kept from Kokkos-Kernels. With the non-negative values computed above, this branch appears unreachable (barring overflow).
  if (rows_per_team < 0) {
    PetscInt nnz_per_team = 4096;
    PetscInt conc         = ExecutionSpace().concurrency();
    while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
    rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Reduce two sets of global indices into local ones

  Input Parameters:
+  n1          - size of garray1[], the first set
.  garray1[n1] - a sorted global index array (without duplicates)
.  m           - size of indices[], the second set
-  indices[m]  - a unsorted global index array (might have duplicates), which will be updated on output into local ones

  Output Parameters:
+  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
.  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
.  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
-  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e, indices_input[i] = garray2[indices_output[i]]

   Example, say
    n1         = 5
    garray1[5] = {1, 4, 7, 8, 10}
    m          = 4
    indices[4] = {2, 4, 8, 9}

   Combining them together, we have 7 global indices in garray2[]
    n2         = 7
    garray2[7] = {1, 2, 4, 7, 8, 9, 10}

   And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
    map[5] = {0, 2, 3, 4, 6}

   On output, indices[] is updated with local indices
    indices[4] = {1, 2, 4, 5}
*/
static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
{
  PetscHMapI    g2l = nullptr;
  PetscHashIter iter;
  PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
  PetscInt      n2, *garray2;

  PetscFunctionBegin;
  tot = 0;
  PetscCall(PetscHMapICreateWithSize(n1, &g2l));
  for (PetscInt i = 0; i < m; i++) {                                 // insert those in indices[]
    PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
    if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
  }

  for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
    PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
    if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
  }

  // Pull out (unique) globals in the hash table and put them in garray2[]
  n2 = tot;
  PetscCall(PetscMalloc1(n2, &garray2));
  tot = 0;
  PetscHashIterBegin(g2l, iter);
  while (!PetscHashIterAtEnd(g2l, iter)) {
    PetscHashIterGetKey(g2l, iter, key);
    PetscHashIterNext(g2l, iter);
    garray2[tot++] = key;
  }

  // Sort garray2[] and then map them to local indices starting from 0
  PetscCall(PetscSortInt(n2, garray2));
  PetscCall(PetscHMapIClear(g2l));
  for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id

  // Rewrite indices[] with local indices
  for (PetscInt i = 0; i < m; i++) {
    PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
    PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
    indices[i] = val;
  }
  // Record the map that maps garray1[i] to garray2[map[i]]
  for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
  PetscCall(PetscHMapIDestroy(&g2l));
  *n2_      = n2;
  *garray2_ = garray2;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)

395 It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E. 396 397 Think each row of E as a leaf, then the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves. 398 In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF. 399 400 Input Parameters: 401 + comm - MPI communicator of E 402 . A - diag block of E, using local column indices 403 . B - off-diag block of E, using local column indices 404 . cstart - (global) start column of Ed 405 . cend - (global) end column + 1 of Ed. In other words, E's column ownership is in range of [cstart, cend) 406 . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size. 407 . ownerSF - the SF specifies ownership (root) of rows in E 408 . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX 409 - mm - to stash intermediate data structures for reuse 410 411 Output Parameters: 412 + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices(). 413 - mm - contains various info, such as garray2[], F (Fd, Fo) etc. 414 415 Notes: 416 When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant. 
417 418 */ 419 static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 420 { 421 PetscFunctionBegin; 422 if (reuse == MAT_INITIAL_MATRIX) { 423 PetscInt Em = A.numRows(), Fm; 424 PetscInt n1 = B.numCols(); 425 426 PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF 427 428 // Do the analysis on host 429 auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map); 430 auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries); 431 auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map); 432 auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries); 433 const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data(); 434 const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data(); 435 436 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 437 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 438 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 439 for (PetscInt i = 0; i < Em; i++) { 440 const PetscInt *first, *last, *it; 441 PetscInt count, step; 442 // std::lower_bound(first,last,cstart), but need to use global column indices 443 first = Bj + Bi[i]; 444 last = Bj + Bi[i + 1]; 445 count = last - first; 446 while (count > 0) { 447 it = first; 448 step = count / 2; 449 it += step; 450 if (garray1[*it] < cstart) { // map local to global 451 first = ++it; 452 count -= step + 1; 453 } else count = step; 454 } 455 E_NzLeft[i] = first - (Bj + Bi[i]); 456 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 457 } 458 459 // Get length of rows (i.e., sizes of leaves) that contribute to my roots 460 const PetscMPIInt *iranks, *ranks; 461 const 
PetscInt *ioffset, *irootloc, *roffset, *rmine; 462 PetscInt niranks, nranks; 463 MPI_Request *reqs; 464 PetscMPIInt tag; 465 PetscSF reduceSF; 466 PetscInt *sdisp, *rdisp; 467 468 PetscCall(PetscCommGetNewTag(comm, &tag)); 469 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them) 470 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them) 471 472 // Find out length of each row I will receive. Even for the same row index, when they are from 473 // different senders, they might have different lengths (and sparsity patterns) 474 PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks]; 475 PetscInt *sendRowLen, *recvRowLen; // lengths of rows of I need to send/recv per process 476 477 PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs)); 478 479 for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]]; 480 recvRowLen[0] = 0; // since we will make it in CSR format later 481 recvRowLen++; // advance the pointer now 482 for (PetscInt i = 0; i < niranks; i++) { MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 483 for (PetscInt i = 0; i < nranks; i++) { MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 484 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 485 486 // Build the real PetscSF for reducing E rows (buffer to buffer) 487 rdisp[0] = 0; 488 for (PetscInt i = 0; i < niranks; i++) { 489 rdisp[i + 1] = rdisp[i]; 490 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; } 491 } 492 recvRowLen--; // put it back into csr format 493 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 
1] += recvRowLen[i]; 494 495 for (PetscInt i = 0; i < nranks; i++) { MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]); } 496 for (PetscInt i = 0; i < niranks; i++) { MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]); } 497 PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE)); 498 499 PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send 500 PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv 501 PetscSFNode *iremote; 502 503 for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i]; 504 PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B"); 505 PetscCallMPI(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF 506 507 for (PetscInt i = 0; i < nranks; i++) { 508 PetscInt count = 0; 509 for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]]; 510 for (PetscInt j = 0; j < count; j++) { 511 iremote[nleaves + j].rank = ranks[i]; 512 iremote[nleaves + j].index = sdisp[i] + j; 513 } 514 nleaves += count; 515 } 516 PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz"); 517 518 PetscCall(PetscSFCreate(comm, &reduceSF)); 519 PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 520 521 // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[] 522 PetscInt *sendCol, *recvCol; 523 PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol)); 524 for (PetscInt k = 0; k < roffset[nranks]; k++) { 525 PetscInt i = rmine[k]; // row to be copied 526 PetscInt *buf = &sendCol[Ai[i] + Bi[i]]; 527 PetscInt nzLeft = E_NzLeft[i]; 528 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 529 for (PetscInt j = 0; j < alen + blen; j++) { 530 if (j < nzLeft) { 531 buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global 532 } else if (j < nzLeft + alen) { 533 buf[j] = Aj[Ai[i] 
+ j - nzLeft] + cstart; // diag A, also in global 534 } else { 535 buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global 536 } 537 } 538 } 539 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE)); 540 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE)); 541 542 // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo 543 PetscInt *recvRowPerm, *recvColSorted; 544 PetscInt *recvNzPerm, *recvNzPermSorted; 545 PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted)); 546 547 for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros 548 for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[] 549 PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed 550 551 // i[] array, nz are always easiest to compute 552 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); 553 MatRowMapType *Fdi, *Foi; 554 PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo 555 PetscInt iter; 556 557 Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them 558 Kokkos::deep_copy(Foi_h, 0); 559 Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below 560 Foi = Foi_h.data() + 1; 561 iter = 0; 562 while (iter < recvRowCnt) { // iter over received rows 563 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; 564 PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns) 565 566 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 567 568 // Copy column indices 
(and their permutation) of these rows into recvColSorted & recvNzPermSorted 569 PetscInt nz = 0; // nz (with dups) in the current row 570 PetscInt *jbuf = recvColSorted + FnzDups; 571 PetscInt *pbuf = recvNzPermSorted + FnzDups; 572 PetscInt *jbuf2 = jbuf; // temp pointers 573 PetscInt *pbuf2 = pbuf; 574 for (PetscInt d = 0; d < dupRows; d++) { 575 PetscInt i = recvRowPerm[iter + d]; 576 PetscInt len = recvRowLen[i + 1] - recvRowLen[i]; 577 PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len)); 578 PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len)); 579 jbuf2 += len; 580 pbuf2 += len; 581 nz += len; 582 } 583 PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted 584 585 // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo 586 PetscInt cur = 0; 587 while (cur < nz) { 588 PetscInt curColIdx = jbuf[cur]; 589 PetscInt dups = 1; 590 591 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 592 if (curColIdx >= cstart && curColIdx < cend) { 593 Fdi[curRowIdx]++; 594 FdnzDups += dups; 595 } else { 596 Foi[curRowIdx]++; 597 FonzDups += dups; 598 } 599 cur += dups; 600 } 601 602 FnzDups += nz; 603 iter += dupRows; // Move to next unique row 604 } 605 606 Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR 607 Foi = Foi_h.data(); 608 for (PetscInt i = 0; i < Fm; i++) { 609 Fdi[i + 1] += Fdi[i]; 610 Foi[i + 1] += Foi[i]; 611 } 612 Fdnz = Fdi[Fm]; 613 Fonz = Foi[Fm]; 614 PetscCall(PetscFree2(sendCol, recvCol)); 615 616 // Allocate j, jmap, jperm for Fd and Fo 617 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 618 MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr 619 MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups); 620 
MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(); 621 MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data(); 622 MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data(); 623 624 // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo 625 Fdjmap[0] = 0; 626 Fojmap[0] = 0; 627 FnzDups = 0; 628 Fdnz = 0; 629 Fonz = 0; 630 iter = 0; // iter over received rows 631 while (iter < recvRowCnt) { 632 PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx 633 PetscInt dupRows = 1; // It has this many contributing rows (of various lengths) 634 PetscInt nz = 0; // nz (with dups) in the current row 635 636 while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++; 637 for (PetscInt d = 0; d < dupRows; d++) { 638 PetscInt i = recvRowPerm[iter + d]; 639 nz += recvRowLen[i + 1] - recvRowLen[i]; 640 } 641 642 PetscInt *jbuf = recvColSorted + FnzDups; 643 // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo 644 PetscInt cur = 0; 645 while (cur < nz) { 646 PetscInt curColIdx = jbuf[cur]; 647 PetscInt dups = 1; 648 649 while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++; 650 if (curColIdx >= cstart && curColIdx < cend) { 651 Fdj[Fdnz] = curColIdx - cstart; // easily convert to local 652 Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups; 653 for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j]; 654 FdnzDups += dups; 655 Fdnz++; 656 } else { 657 Foj[Fonz] = curColIdx; // in global 658 Fojmap[Fonz + 1] = Fojmap[Fonz] + dups; 659 for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j]; 660 FonzDups += dups; 661 Fonz++; 662 } 663 cur += dups; 664 FnzDups += dups; 665 } 666 iter += dupRows; // Move to next unique row 667 } 668 PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted)); 669 PetscCall(PetscFree5(sendRowLen, recvRowLen, 
sdisp, rdisp, reqs)); 670 671 // Combine global column indices in garray1[] and Foj[] 672 PetscInt n2, *garray2; 673 674 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 675 mm->sf = reduceSF; 676 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 677 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 678 mm->garray = garray2; // give ownership, so no free 679 mm->n = n2; 680 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 681 mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h); 682 mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h); 683 mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h); 684 mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h); 685 686 // Output Fd and Fo in KokkosCsrMatrix format 687 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz); 688 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 689 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 690 MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz); 691 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 692 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 693 694 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 695 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[] 696 697 // Compute kernel launch parameters in merging E 698 PetscInt teamSize, vectorLength, rowsPerTeam; 699 700 teamSize = vectorLength = rowsPerTeam = -1; 701 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam)); 702 mm->E_TeamSize = teamSize; 703 mm->E_VectorLength = 
vectorLength; 704 mm->E_RowsPerTeam = rowsPerTeam; 705 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 706 707 // Handy aliases 708 auto &Aa = A.values; 709 auto &Ba = B.values; 710 const auto &Ai = A.graph.row_map; 711 const auto &Bi = B.graph.row_map; 712 const auto &E_NzLeft = mm->E_NzLeft; 713 auto &leafBuf = mm->leafBuf; 714 auto &rootBuf = mm->rootBuf; 715 PetscSF reduceSF = mm->sf; 716 PetscInt Em = A.numRows(); 717 PetscInt teamSize = mm->E_TeamSize; 718 PetscInt vectorLength = mm->E_VectorLength; 719 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 720 PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam; 721 722 // Copy rows in A/B of E to leafBuf, then pass it to rootBuf 723 PetscCallCXX(Kokkos::parallel_for( 724 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 725 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 726 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 727 if (i < Em) { 728 PetscInt disp = Ai(i) + Bi(i); 729 PetscInt alen = Ai(i + 1) - Ai(i); 730 PetscInt blen = Bi(i + 1) - Bi(i); 731 PetscInt nzleft = E_NzLeft(i); 732 733 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 734 MatScalar &val = leafBuf(disp + j); 735 if (j < nzleft) { // B left 736 val = Ba(Bi(i) + j); 737 } else if (j < nzleft + alen) { // diag A 738 val = Aa(Ai(i) + j - nzleft); 739 } else { // B right 740 val = Ba(Bi(i) + j - alen); 741 } 742 }); 743 } 744 }); 745 })); 746 PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE)); 747 PetscFunctionReturn(PETSC_SUCCESS); 748 } 749 750 // To finish MatMPIAIJKokkosReduce. 
751 static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm) 752 { 753 auto &leafBuf = mm->leafBuf; 754 auto &rootBuf = mm->rootBuf; 755 auto &Fda = mm->Fd.values; 756 const auto &Fdjmap = mm->Fdjmap; 757 const auto &Fdjperm = mm->Fdjperm; 758 auto Fdnz = mm->Fd.nnz(); 759 auto &Foa = mm->Fo.values; 760 const auto &Fojmap = mm->Fojmap; 761 const auto &Fojperm = mm->Fojperm; 762 auto Fonz = mm->Fo.nnz(); 763 PetscSF reduceSF = mm->sf; 764 765 PetscFunctionBegin; 766 PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE)); 767 768 // Reduce data in rootBuf to Fd and Fo 769 PetscCallCXX(Kokkos::parallel_for( 770 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) { 771 PetscScalar sum = 0.0; 772 for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k)); 773 Fda(i) = sum; 774 })); 775 776 PetscCallCXX(Kokkos::parallel_for( 777 Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) { 778 PetscScalar sum = 0.0; 779 for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k)); 780 Foa(i) = sum; 781 })); 782 PetscFunctionReturn(PETSC_SUCCESS); 783 } 784 785 /* 786 MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form 787 788 This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports 789 device and involves various index mapping. 790 791 In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves. 
  Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i); it means we need to bcast the i-th row of E on rank k
  to the j-th row of F. ownerSF is not an arbitrary SF; instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
  F has the same column layout as E.

  Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
  Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
  Fo has global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
  column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
  the column indices in Fo and update Fo with local indices.

  Input Parameters:
+ E       - the MPIAIJKOKKOS matrix
. ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
. reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
- mm      - to stash matproduct intermediate data structures

  Output Parameters:
+ map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
- mm      - contains various info, such as garray2[], Fd, Fo, etc.

  Notes:
  When reuse = MAT_REUSE_MATRIX, ownerSF and map are not significant.
  The routine is provided in split-phase form MatMPIAIJKokkosBcastBegin/End() to provide computation/communication overlapping opportunities.
815 */ 816 static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 817 { 818 Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data); 819 Mat A = empi->A, B = empi->B; // diag and off-diag 820 Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr); 821 PetscInt Em = E->rmap->n; // #local rows 822 MPI_Comm comm; 823 824 PetscFunctionBegin; 825 PetscCallMPI(PetscObjectGetComm((PetscObject)E, &comm)); 826 if (reuse == MAT_INITIAL_MATRIX) { 827 Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data); 828 PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j; 829 const PetscInt *garray1 = empi->garray; // its size is n1 830 PetscInt cstart, cend; 831 PetscSF bcastSF; 832 833 PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend)); 834 835 // Count how many nonzeros of each row in E are in the left of the diag block [cstart,cend) 836 PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em); 837 PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data(); 838 for (PetscInt i = 0; i < Em; i++) { 839 const PetscInt *first, *last, *it; 840 PetscInt count, step; 841 // std::lower_bound(first,last,cstart), but need to use global column indices 842 first = Bj + Bi[i]; 843 last = Bj + Bi[i + 1]; 844 count = last - first; 845 while (count > 0) { 846 it = first; 847 step = count / 2; 848 it += step; 849 if (empi->garray[*it] < cstart) { // map local to global 850 first = ++it; 851 count -= step + 1; 852 } else count = step; 853 } 854 E_NzLeft[i] = first - (Bj + Bi[i]); 855 E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]); 856 } 857 858 // Compute row pointer Fi of F 859 PetscInt *Fi, Fm, Fnz; 860 PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF 861 PetscCall(PetscMalloc1(Fm 
+ 1, &Fi)); 862 Fi[0] = 0; 863 PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE)); 864 PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE)); 865 for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i]; 866 Fnz = Fi[Fm]; 867 868 // Build the real PetscSF for bcasting E rows (buffer to buffer) 869 const PetscMPIInt *iranks, *ranks; 870 const PetscInt *ioffset, *irootloc, *roffset; 871 PetscInt niranks, nranks, *sdisp, *rdisp; 872 MPI_Request *reqs; 873 PetscMPIInt tag; 874 875 PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process 876 PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info 877 PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs)); 878 879 sdisp[0] = 0; // send displacement 880 for (PetscInt i = 0; i < niranks; i++) { 881 sdisp[i + 1] = sdisp[i]; 882 for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { 883 PetscInt r = irootloc[j]; // row to be sent 884 sdisp[i + 1] += E_RowLen[r]; 885 } 886 } 887 888 PetscCallMPI(PetscCommGetNewTag(comm, &tag)); 889 for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); 890 for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); 891 PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE)); 892 893 PetscInt nleaves = Fnz; // leaves are nonzeros I will receive 894 PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send 895 PetscSFNode *iremote; // give ownership to bcastSF 896 PetscCall(PetscMalloc1(nleaves, &iremote)); 897 for (PetscInt i = 0; i < nranks; i++) { // for each sender rank 898 PetscInt k = 0; 899 for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], 
roffset[i+1]) of F from ranks[i] 900 iremote[j].rank = ranks[i]; 901 iremote[j].index = rdisp[i] + k; // their root location 902 k++; 903 } 904 } 905 PetscCall(PetscSFCreate(comm, &bcastSF)); 906 PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER)); 907 PetscCall(PetscFree3(sdisp, rdisp, reqs)); 908 909 // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel 910 PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1); 911 PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destinate offset in copying 912 rowoffset[0] = 0; 913 for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; } 914 915 // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[] 916 PetscInt *jbuf, *Fj; 917 PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj)); 918 for (PetscInt k = 0; k < ioffset[niranks]; k++) { 919 PetscInt i = irootloc[k]; // row to be copied 920 PetscInt *buf = &jbuf[rowoffset[k]]; 921 PetscInt nzLeft = E_NzLeft[i]; 922 PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i]; 923 for (PetscInt j = 0; j < alen + blen; j++) { 924 if (j < nzLeft) { 925 buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global 926 } else if (j < nzLeft + alen) { 927 buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global 928 } else { 929 buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global 930 } 931 } 932 } 933 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE)); 934 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE)); 935 936 // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo 937 MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo 938 MatColIdxKokkosViewHost 
F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag. 939 MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data(); 940 MatColIdxType *F_NzLeft = F_NzLeft_h.data(); 941 942 Fdi[0] = Foi[0] = 0; 943 for (PetscInt i = 0; i < Fm; i++) { 944 PetscInt *first, *last, *lb1, *lb2; 945 // cut the row into: Left, [cstart, cend), Right 946 first = Fj + Fi[i]; 947 last = Fj + Fi[i + 1]; 948 lb1 = std::lower_bound(first, last, cstart); 949 F_NzLeft[i] = lb1 - first; 950 lb2 = std::lower_bound(first, last, cend); 951 Fdi[i + 1] = lb2 - lb1; // row i length in Fdi 952 Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi 953 } 954 for (PetscInt i = 0; i < Fm; i++) { 955 Fdi[i + 1] += Fdi[i]; 956 Foi[i + 1] += Foi[i]; 957 } 958 959 // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet. 960 PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm]; 961 MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz); 962 MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid; 963 964 for (PetscInt i = 0; i < Fm; i++) { 965 PetscInt nzLeft = F_NzLeft[i]; 966 PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len 967 for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) { 968 gid = Fj[Fi[i] + j]; 969 if (j < nzLeft) { // left, in global 970 Foj[Foi[i] + j] = gid; 971 } else if (j < nzLeft + len) { // diag, in local 972 Fdj[Fdi[i] + j - nzLeft] = gid - cstart; 973 } else { // right, in global 974 Foj[Foi[i] + j - len] = gid; 975 } 976 } 977 } 978 PetscCall(PetscFree2(jbuf, Fj)); 979 PetscCall(PetscFree(Fi)); 980 981 // Reduce global indices in Foj[] and garray1[] into local ones 982 PetscInt n2, *garray2; 983 PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map)); 984 985 // Record the plans built above, for reuse 986 PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. 
We create a copy for safety 987 PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]); 988 Kokkos::deep_copy(irootloc_h, tmp); 989 mm->sf = bcastSF; 990 mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h); 991 mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h); 992 mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h); 993 mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h); 994 mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots); 995 mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves); 996 mm->garray = garray2; 997 mm->n = n2; 998 999 // Output Fd and Fo in KokkosCsrMatrix format 1000 MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz); 1001 MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h); 1002 MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h); 1003 MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h); 1004 MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h); 1005 1006 PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d)); 1007 PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); 1008 1009 // Compute kernel launch parameters in merging E or splitting F 1010 PetscInt teamSize, vectorLength, rowsPerTeam; 1011 1012 teamSize = vectorLength = rowsPerTeam = -1; 1013 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam)); 1014 mm->E_TeamSize = teamSize; 1015 mm->E_VectorLength = vectorLength; 1016 mm->E_RowsPerTeam = rowsPerTeam; 1017 1018 teamSize = vectorLength = rowsPerTeam = -1; 1019 PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, 
Fnz, -1, teamSize, vectorLength, rowsPerTeam)); 1020 mm->F_TeamSize = teamSize; 1021 mm->F_VectorLength = vectorLength; 1022 mm->F_RowsPerTeam = rowsPerTeam; 1023 } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse); 1024 1025 // Sync E's value to device 1026 akok->a_dual.sync_device(); 1027 bkok->a_dual.sync_device(); 1028 1029 // Handy aliases 1030 const auto &Aa = akok->a_dual.view_device(); 1031 const auto &Ba = bkok->a_dual.view_device(); 1032 const auto &Ai = akok->i_dual.view_device(); 1033 const auto &Bi = bkok->i_dual.view_device(); 1034 1035 // Fetch the plans 1036 PetscIntKokkosView &E_NzLeft = mm->E_NzLeft; 1037 PetscSF &bcastSF = mm->sf; 1038 MatScalarKokkosView &rootBuf = mm->rootBuf; 1039 MatScalarKokkosView &leafBuf = mm->leafBuf; 1040 PetscIntKokkosView &irootloc = mm->irootloc; 1041 PetscIntKokkosView &rowoffset = mm->rowoffset; 1042 1043 PetscInt teamSize = mm->E_TeamSize; 1044 PetscInt vectorLength = mm->E_VectorLength; 1045 PetscInt rowsPerTeam = mm->E_RowsPerTeam; 1046 PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam; 1047 1048 // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf 1049 PetscCallCXX(Kokkos::parallel_for( 1050 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1051 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1052 size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[] 1053 if (r < irootloc.extent(0)) { 1054 PetscInt i = irootloc(r); // row i of E 1055 PetscInt disp = rowoffset(r); 1056 PetscInt alen = Ai(i + 1) - Ai(i); 1057 PetscInt blen = Bi(i + 1) - Bi(i); 1058 PetscInt nzleft = E_NzLeft(i); 1059 1060 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1061 if (j < nzleft) { // B left 1062 rootBuf(disp + j) = Ba(Bi(i) + j); 1063 } else if (j < nzleft + alen) { // diag A 1064 
rootBuf(disp + j) = Aa(Ai(i) + j - nzleft); 1065 } else { // B right 1066 rootBuf(disp + j) = Ba(Bi(i) + j - alen); 1067 } 1068 }); 1069 } 1070 }); 1071 })); 1072 PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE)); 1073 PetscFunctionReturn(PETSC_SUCCESS); 1074 } 1075 1076 // To finish MatMPIAIJKokkosBcast. 1077 static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm) 1078 { 1079 PetscFunctionBegin; 1080 const auto &Fd = mm->Fd; 1081 const auto &Fo = mm->Fo; 1082 const auto &Fdi = Fd.graph.row_map; 1083 const auto &Foi = Fo.graph.row_map; 1084 auto &Fda = Fd.values; 1085 auto &Foa = Fo.values; 1086 auto Fm = Fd.numRows(); 1087 1088 PetscIntKokkosView &F_NzLeft = mm->F_NzLeft; 1089 PetscSF &bcastSF = mm->sf; 1090 MatScalarKokkosView &rootBuf = mm->rootBuf; 1091 MatScalarKokkosView &leafBuf = mm->leafBuf; 1092 PetscInt teamSize = mm->F_TeamSize; 1093 PetscInt vectorLength = mm->F_VectorLength; 1094 PetscInt rowsPerTeam = mm->F_RowsPerTeam; 1095 PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam; 1096 1097 PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE)); 1098 1099 // Update Fda and Foa with new data in leafBuf (as if it is Fa) 1100 PetscCallCXX(Kokkos::parallel_for( 1101 Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) { 1102 Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) { 1103 PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F 1104 if (i < Fm) { 1105 PetscInt nzLeft = F_NzLeft(i); 1106 PetscInt alen = Fdi(i + 1) - Fdi(i); 1107 PetscInt blen = Foi(i + 1) - Foi(i); 1108 PetscInt Fii = Fdi(i) + Foi(i); 1109 1110 Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) { 1111 PetscScalar val = leafBuf(Fii 
+ j); 1112 if (j < nzLeft) { // left 1113 Foa(Foi(i) + j) = val; 1114 } else if (j < nzLeft + alen) { // diag 1115 Fda(Fdi(i) + j - nzLeft) = val; 1116 } else { // right 1117 Foa(Foi(i) + j - alen) = val; 1118 } 1119 }); 1120 } 1121 }); 1122 })); 1123 PetscFunctionReturn(PETSC_SUCCESS); 1124 } 1125 1126 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1127 { 1128 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1129 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1130 KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo; 1131 PetscInt cstart, cend; 1132 MPI_Comm comm; 1133 1134 PetscFunctionBegin; 1135 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1136 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1137 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1138 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1139 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1140 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1141 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1142 1143 // TODO: add command line options to select spgemm algorithms 1144 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1145 1146 // CUDA-10.2's spgemm has bugs. 
We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1147 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1148 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1149 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1150 #endif 1151 #endif 1152 1153 PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg)); 1154 PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg)); 1155 PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg)); 1156 PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg)); 1157 1158 // Aot * (B's diag + B's off-diag) 1159 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3)); 1160 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4)); 1161 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1162 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 1163 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1164 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1165 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1166 1167 PetscCallCXX(sort_crs_matrix(mm->C3)); 1168 PetscCallCXX(sort_crs_matrix(mm->C4)); 1169 #endif 1170 1171 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1172 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1173 PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend)); 1174 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1175 1176 // Adt * (B's diag + B's off-diag) 1177 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1)); 1178 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1179 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1180 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1181 #if 
PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1182 PetscCallCXX(sort_crs_matrix(mm->C1)); 1183 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1184 #endif 1185 1186 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1187 1188 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1189 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1190 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1191 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1192 PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj)); 1193 1194 // C = (C1+Fd, C2+Fo) 1195 PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted 1196 PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted 1197 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd)); 1198 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co)); 1199 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1200 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1201 PetscFunctionReturn(PETSC_SUCCESS); 1202 } 1203 1204 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm) 1205 { 1206 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1207 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1208 KokkosCsrMatrix Adt, Aot, Bd, Bo; 1209 MPI_Comm comm; 1210 1211 PetscFunctionBegin; 1212 PetscCall(PetscObjectGetComm((PetscObject)B, &comm)); 1213 
PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt)); 1214 PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot)); 1215 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1216 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1217 1218 // Aot * (B's diag + B's off-diag) 1219 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3)); 1220 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4)); 1221 1222 // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication 1223 PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1224 1225 // Adt * (B's diag + B's off-diag) 1226 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1)); 1227 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid)); 1228 1229 PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1230 1231 // C = (C1+Fd, C2+Fo) 1232 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd)); 1233 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co)); 1234 PetscFunctionReturn(PETSC_SUCCESS); 1235 } 1236 1237 /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos 1238 1239 Input Parameters: 1240 + product - Mat_Product which carried out the computation. Passed in to access info about this mat product. 1241 . A - an MPIAIJKOKKOS matrix 1242 . B - an MPIAIJKOKKOS matrix 1243 - mm - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations. 
1244 */ 1245 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1246 { 1247 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1248 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1249 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1250 1251 PetscFunctionBegin; 1252 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1253 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1254 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1255 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1256 1257 // TODO: add command line options to select spgemm algorithms 1258 auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK 1259 1260 // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 1261 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) 1262 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0) 1263 spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK; 1264 #endif 1265 #endif 1266 1267 mm->kh1.create_spgemm_handle(spgemm_alg); 1268 mm->kh2.create_spgemm_handle(spgemm_alg); 1269 mm->kh3.create_spgemm_handle(spgemm_alg); 1270 mm->kh4.create_spgemm_handle(spgemm_alg); 1271 1272 // Bcast B's rows to form F, and overlap the communication 1273 PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n); 1274 PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1275 1276 // A's diag * (B's diag + B's off-diag) 1277 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1)); 1278 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices 1279 // KK spgemm_symbolic() only populates the result's row map, but not its columns. 1280 // TODO: Remove the fake spgemm_numeric() after KK fixed this problem. 
1281 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1282 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1283 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1284 PetscCallCXX(sort_crs_matrix(mm->C1)); 1285 PetscCallCXX(sort_crs_matrix(mm->C2_mid)); 1286 #endif 1287 1288 PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm)); 1289 1290 // A's off-diag * (F's diag + F's off-diag) 1291 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1292 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1293 PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1294 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1295 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0) 1296 PetscCallCXX(sort_crs_matrix(mm->C3)); 1297 PetscCallCXX(sort_crs_matrix(mm->C4)); 1298 #endif 1299 1300 // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size 1301 MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0)); 1302 PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h); 1303 PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); })); 1304 mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj); 1305 1306 // C = (Cd, Co) = (C1+C3, C2+C4) 1307 mm->kh1.create_spadd_handle(true); // C1, C3 are sorted 1308 mm->kh2.create_spadd_handle(true); // C2, C4 are sorted 1309 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd)); 1310 PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co)); 1311 
PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1312 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1313 PetscFunctionReturn(PETSC_SUCCESS); 1314 } 1315 1316 static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm) 1317 { 1318 Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data); 1319 Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data); 1320 KokkosCsrMatrix Ad, Ao, Bd, Bo; 1321 1322 PetscFunctionBegin; 1323 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad)); 1324 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao)); 1325 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd)); 1326 PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo)); 1327 1328 // Bcast B's rows to form F, and overlap the communication 1329 PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1330 1331 // A's diag * (B's diag + B's off-diag) 1332 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1)); 1333 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); 1334 1335 PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm)); 1336 1337 // A's off-diag * (F's diag + F's off-diag) 1338 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3)); 1339 PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4)); 1340 1341 // C = (Cd, Co) = (C1+C3, C2+C4) 1342 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd)); 1343 PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co)); 1344 PetscFunctionReturn(PETSC_SUCCESS); 1345 } 1346 1347 static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C) 1348 { 1349 Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data); 1350 Mat_Product *product; 1351 MatProductData_MPIAIJKokkos *pdata; 1352 
MatProductType ptype; 1353 Mat A, B; 1354 1355 PetscFunctionBegin; 1356 MatCheckProduct(C, 1); // make sure C is a product 1357 product = C->product; 1358 pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data); 1359 ptype = product->type; 1360 A = product->A; 1361 B = product->B; 1362 1363 // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)). 1364 // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), 1365 // we still do numeric. 1366 if (pdata->reusesym) { // numeric reuses results from symbolic 1367 pdata->reusesym = PETSC_FALSE; 1368 PetscFunctionReturn(PETSC_SUCCESS); 1369 } 1370 1371 if (ptype == MATPRODUCT_AB) { 1372 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1373 } else if (ptype == MATPRODUCT_AtB) { 1374 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB)); 1375 } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C= BtZ 1376 PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB)); 1377 PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB)); 1378 } 1379 1380 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that A, B on device are modified 1381 PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B)); 1382 PetscFunctionReturn(PETSC_SUCCESS); 1383 } 1384 1385 static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C) 1386 { 1387 Mat A, B; 1388 Mat_Product *product; 1389 MatProductType ptype; 1390 MatProductData_MPIAIJKokkos *pdata; 1391 MatMatStruct *mm = NULL; 1392 PetscInt m, n, M, N; 1393 Mat Cd, Co; 1394 MPI_Comm comm; 1395 1396 PetscFunctionBegin; 1397 PetscCall(PetscObjectGetComm((PetscObject)C, &comm)); 1398 MatCheckProduct(C, 1); 1399 product = C->product; 1400 PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty"); 1401 ptype = product->type; 1402 A = product->A; 1403 B = product->B; 
1404 1405 switch (ptype) { 1406 case MATPRODUCT_AB: 1407 m = A->rmap->n; 1408 n = B->cmap->n; 1409 M = A->rmap->N; 1410 N = B->cmap->N; 1411 break; 1412 case MATPRODUCT_AtB: 1413 m = A->cmap->n; 1414 n = B->cmap->n; 1415 M = A->cmap->N; 1416 N = B->cmap->N; 1417 break; 1418 case MATPRODUCT_PtAP: 1419 m = B->cmap->n; 1420 n = B->cmap->n; 1421 M = B->cmap->N; 1422 N = B->cmap->N; 1423 break; /* BtAB */ 1424 default: 1425 SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]); 1426 } 1427 1428 PetscCall(MatSetSizes(C, m, n, M, N)); 1429 PetscCall(PetscLayoutSetUp(C->rmap)); 1430 PetscCall(PetscLayoutSetUp(C->cmap)); 1431 PetscCall(MatSetType(C, ((PetscObject)A)->type_name)); 1432 1433 pdata = new MatProductData_MPIAIJKokkos(); 1434 pdata->reusesym = product->api_user; 1435 1436 if (ptype == MATPRODUCT_AB) { 1437 auto mmAB = new MatMatStruct_AB(); 1438 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); 1439 mm = pdata->mmAB = mmAB; 1440 } else if (ptype == MATPRODUCT_AtB) { 1441 auto mmAtB = new MatMatStruct_AtB(); 1442 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB)); 1443 mm = pdata->mmAtB = mmAtB; 1444 } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C= BtZ 1445 Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z 1446 1447 auto mmAB = new MatMatStruct_AB(); 1448 PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co} 1449 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd)); 1450 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo)); 1451 pdata->mmAB = mmAB; 1452 1453 m = A->rmap->n; // Z's layout 1454 n = B->cmap->n; 1455 M = A->rmap->N; 1456 N = B->cmap->N; 1457 PetscCall(MatCreate(comm, &Z)); 1458 PetscCall(MatSetSizes(Z, m, n, M, N)); 1459 PetscCall(PetscLayoutSetUp(Z->rmap)); 1460 PetscCall(PetscLayoutSetUp(Z->cmap)); 1461 PetscCall(MatSetType(Z, MATMPIAIJKOKKOS)); 1462 
PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray)); 1463 1464 auto mmAtB = new MatMatStruct_AtB(); 1465 PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co} 1466 1467 pdata->Z = Z; // give ownership to pdata 1468 mm = pdata->mmAtB = mmAtB; 1469 } 1470 1471 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd)); 1472 PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co)); 1473 PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray)); 1474 1475 C->product->data = pdata; 1476 C->product->destroy = MatProductDataDestroy_MPIAIJKokkos; 1477 C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos; 1478 PetscFunctionReturn(PETSC_SUCCESS); 1479 } 1480 1481 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat) 1482 { 1483 Mat_Product *product = mat->product; 1484 PetscBool match = PETSC_FALSE; 1485 PetscBool usecpu = PETSC_FALSE; 1486 1487 PetscFunctionBegin; 1488 MatCheckProduct(mat, 1); 1489 if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match)); 1490 if (match) { /* we can always fallback to the CPU if requested */ 1491 switch (product->type) { 1492 case MATPRODUCT_AB: 1493 if (product->api_user) { 1494 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat"); 1495 PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1496 PetscOptionsEnd(); 1497 } else { 1498 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat"); 1499 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL)); 1500 PetscOptionsEnd(); 1501 } 1502 break; 1503 case MATPRODUCT_AtB: 1504 if 
(product->api_user) { 1505 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat"); 1506 PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1507 PetscOptionsEnd(); 1508 } else { 1509 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat"); 1510 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL)); 1511 PetscOptionsEnd(); 1512 } 1513 break; 1514 case MATPRODUCT_PtAP: 1515 if (product->api_user) { 1516 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat"); 1517 PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1518 PetscOptionsEnd(); 1519 } else { 1520 PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat"); 1521 PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL)); 1522 PetscOptionsEnd(); 1523 } 1524 break; 1525 default: 1526 break; 1527 } 1528 match = (PetscBool)!usecpu; 1529 } 1530 if (match) { 1531 switch (product->type) { 1532 case MATPRODUCT_AB: 1533 case MATPRODUCT_AtB: 1534 case MATPRODUCT_PtAP: 1535 mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos; 1536 break; 1537 default: 1538 break; 1539 } 1540 } 1541 /* fallback to MPIAIJ ops */ 1542 if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat)); 1543 PetscFunctionReturn(PETSC_SUCCESS); 1544 } 1545 1546 // Mirror of MatCOOStruct_MPIAIJ on device 1547 struct MatCOOStruct_MPIAIJKokkos { 1548 PetscCount n; 1549 PetscSF sf; 1550 PetscCount Annz, Bnnz; 1551 PetscCount Annz2, Bnnz2; 1552 PetscCountKokkosView Ajmap1, Aperm1; 1553 PetscCountKokkosView Bjmap1, Bperm1; 1554 PetscCountKokkosView Aimap2, Ajmap2, Aperm2; 1555 
PetscCountKokkosView Bimap2, Bjmap2, Bperm2; 1556 PetscCountKokkosView Cperm1; 1557 MatScalarKokkosView sendbuf, recvbuf; 1558 1559 MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h) 1560 { 1561 auto &exec = PetscGetKokkosExecutionSpace(); 1562 1563 n = coo_h->n; 1564 sf = coo_h->sf; 1565 Annz = coo_h->Annz; 1566 Bnnz = coo_h->Bnnz; 1567 Annz2 = coo_h->Annz2; 1568 Bnnz2 = coo_h->Bnnz2; 1569 Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1)); 1570 Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1)); 1571 Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1)); 1572 Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1)); 1573 Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2)); 1574 Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1)); 1575 Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2)); 1576 Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2)); 1577 Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1)); 1578 Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2)); 1579 Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen)); 1580 sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen)); 1581 recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen)); 1582 PetscCallVoid(PetscObjectReference((PetscObject)sf)); 1583 } 1584 1585 
~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); } 1586 }; 1587 1588 static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data) 1589 { 1590 PetscFunctionBegin; 1591 PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data)); 1592 PetscFunctionReturn(PETSC_SUCCESS); 1593 } 1594 1595 static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 1596 { 1597 PetscContainer container_h, container_d; 1598 MatCOOStruct_MPIAIJ *coo_h; 1599 MatCOOStruct_MPIAIJKokkos *coo_d; 1600 1601 PetscFunctionBegin; 1602 PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */ 1603 mat->preallocated = PETSC_TRUE; 1604 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 1605 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 1606 PetscCall(MatZeroEntries(mat)); 1607 1608 // Copy the COO struct to device 1609 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 1610 PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h)); 1611 PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h)); 1612 1613 // Put the COO struct in a container and then attach that to the matrix 1614 PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d)); 1615 PetscCall(PetscContainerSetPointer(container_d, coo_d)); 1616 PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos)); 1617 PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d)); 1618 PetscCall(PetscContainerDestroy(&container_d)); 1619 PetscFunctionReturn(PETSC_SUCCESS); 1620 } 1621 1622 static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode) 1623 { 1624 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data); 1625 Mat A = mpiaij->A, B = mpiaij->B; 1626 MatScalarKokkosView Aa, Ba; 1627 MatScalarKokkosView v1; 
1628 PetscMemType memtype; 1629 PetscContainer container; 1630 MatCOOStruct_MPIAIJKokkos *coo; 1631 Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 1632 1633 PetscFunctionBegin; 1634 PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 1635 PetscCall(PetscContainerGetPointer(container, (void **)&coo)); 1636 1637 const auto &n = coo->n; 1638 const auto &Annz = coo->Annz; 1639 const auto &Annz2 = coo->Annz2; 1640 const auto &Bnnz = coo->Bnnz; 1641 const auto &Bnnz2 = coo->Bnnz2; 1642 const auto &vsend = coo->sendbuf; 1643 const auto &v2 = coo->recvbuf; 1644 const auto &Ajmap1 = coo->Ajmap1; 1645 const auto &Ajmap2 = coo->Ajmap2; 1646 const auto &Aimap2 = coo->Aimap2; 1647 const auto &Bjmap1 = coo->Bjmap1; 1648 const auto &Bjmap2 = coo->Bjmap2; 1649 const auto &Bimap2 = coo->Bimap2; 1650 const auto &Aperm1 = coo->Aperm1; 1651 const auto &Aperm2 = coo->Aperm2; 1652 const auto &Bperm1 = coo->Bperm1; 1653 const auto &Bperm2 = coo->Bperm2; 1654 const auto &Cperm1 = coo->Cperm1; 1655 1656 PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */ 1657 if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we need to copy it to device if any */ 1658 v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n)); 1659 } else { 1660 v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */ 1661 } 1662 1663 if (imode == INSERT_VALUES) { 1664 PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */ 1665 PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba)); 1666 } else { 1667 PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */ 1668 PetscCall(MatSeqAIJGetKokkosView(B, &Ba)); 1669 } 1670 1671 PetscCall(PetscLogGpuTimeBegin()); 1672 /* Pack entries to be sent to remote */ 1673 Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) 
= v1(Cperm1(i)); }); 1674 1675 /* Send remote entries to their owner and overlap the communication with local computation */ 1676 PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE)); 1677 /* Add local entries to A and B in one kernel */ 1678 Kokkos::parallel_for( 1679 Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) { 1680 PetscScalar sum = 0.0; 1681 if (i < Annz) { 1682 for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k)); 1683 Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum; 1684 } else { 1685 i -= Annz; 1686 for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k)); 1687 Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum; 1688 } 1689 }); 1690 PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE)); 1691 1692 /* Add received remote entries to A and B in one kernel */ 1693 Kokkos::parallel_for( 1694 Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) { 1695 if (i < Annz2) { 1696 for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k)); 1697 } else { 1698 i -= Annz2; 1699 for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k)); 1700 } 1701 }); 1702 PetscCall(PetscLogGpuTimeEnd()); 1703 1704 if (imode == INSERT_VALUES) { 1705 PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. 
*/ 1706 PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba)); 1707 } else { 1708 PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa)); 1709 PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba)); 1710 } 1711 PetscFunctionReturn(PETSC_SUCCESS); 1712 } 1713 1714 static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A) 1715 { 1716 PetscFunctionBegin; 1717 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL)); 1718 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL)); 1719 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 1720 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 1721 PetscCall(MatDestroy_MPIAIJ(A)); 1722 PetscFunctionReturn(PETSC_SUCCESS); 1723 } 1724 1725 static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a) 1726 { 1727 Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data); 1728 PetscBool congruent; 1729 1730 PetscFunctionBegin; 1731 PetscCall(MatHasCongruentLayouts(A, &congruent)); 1732 if (congruent) { // square matrix and the diagonals are solely in the diag block 1733 PetscCall(MatShift(mpiaij->A, a)); 1734 } else { // too hard, use the general version 1735 PetscCall(MatShift_Basic(A, a)); 1736 } 1737 PetscFunctionReturn(PETSC_SUCCESS); 1738 } 1739 1740 static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B) 1741 { 1742 PetscFunctionBegin; 1743 B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos; 1744 B->ops->mult = MatMult_MPIAIJKokkos; 1745 B->ops->multadd = MatMultAdd_MPIAIJKokkos; 1746 B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos; 1747 B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos; 1748 B->ops->destroy = MatDestroy_MPIAIJKokkos; 1749 B->ops->shift = MatShift_MPIAIJKokkos; 1750 1751 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos)); 1752 PetscCall(PetscObjectComposeFunction((PetscObject)B, 
"MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos)); 1753 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos)); 1754 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos)); 1755 PetscFunctionReturn(PETSC_SUCCESS); 1756 } 1757 1758 PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 1759 { 1760 Mat B; 1761 Mat_MPIAIJ *a; 1762 1763 PetscFunctionBegin; 1764 if (reuse == MAT_INITIAL_MATRIX) { 1765 PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 1766 } else if (reuse == MAT_REUSE_MATRIX) { 1767 PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 1768 } 1769 B = *newmat; 1770 1771 B->boundtocpu = PETSC_FALSE; 1772 PetscCall(PetscFree(B->defaultvectype)); 1773 PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype)); 1774 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS)); 1775 1776 a = static_cast<Mat_MPIAIJ *>(A->data); 1777 if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS)); 1778 if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS)); 1779 if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS)); 1780 PetscCall(MatSetOps_MPIAIJKokkos(B)); 1781 PetscFunctionReturn(PETSC_SUCCESS); 1782 } 1783 1784 /*MC 1785 MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos 1786 1787 A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types 1788 1789 Options Database Key: 1790 . 
-mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS` 1791 1792 Level: beginner 1793 1794 .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ` 1795 M*/ 1796 PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A) 1797 { 1798 PetscFunctionBegin; 1799 PetscCall(PetscKokkosInitializeCheck()); 1800 PetscCall(MatCreate_MPIAIJ(A)); 1801 PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A)); 1802 PetscFunctionReturn(PETSC_SUCCESS); 1803 } 1804 1805 /*@C 1806 MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKOS` (compressed row) format 1807 (the default parallel PETSc format). This matrix will ultimately pushed down 1808 to Kokkos for calculations. 1809 1810 Collective 1811 1812 Input Parameters: 1813 + comm - MPI communicator, set to `PETSC_COMM_SELF` 1814 . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given) 1815 This value should be the same as the local size used in creating the 1816 y vector for the matrix-vector product y = Ax. 1817 . n - This value should be the same as the local size used in creating the 1818 x vector for the matrix-vector product y = Ax. (or `PETSC_DECIDE` to have 1819 calculated if N is given) For square matrices n is almost always `m`. 1820 . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given) 1821 . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given) 1822 . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix 1823 (same value is used for all local rows) 1824 . d_nnz - array containing the number of nonzeros in the various rows of the 1825 DIAGONAL portion of the local submatrix (possibly different for each row) 1826 or `NULL`, if `d_nz` is used to specify the nonzero structure. 1827 The size of this array is equal to the number of local rows, i.e `m`. 
1828 For matrices you plan to factor you must leave room for the diagonal entry and 1829 put in the entry even if it is zero. 1830 . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local 1831 submatrix (same value is used for all local rows). 1832 - o_nnz - array containing the number of nonzeros in the various rows of the 1833 OFF-DIAGONAL portion of the local submatrix (possibly different for 1834 each row) or `NULL`, if `o_nz` is used to specify the nonzero 1835 structure. The size of this array is equal to the number 1836 of local rows, i.e `m`. 1837 1838 Output Parameter: 1839 . A - the matrix 1840 1841 Level: intermediate 1842 1843 Notes: 1844 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 1845 MatXXXXSetPreallocation() paradigm instead of this routine directly. 1846 [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`] 1847 1848 The AIJ format, also called compressed row storage), is fully compatible with standard Fortran 1849 storage. That is, the stored row and column indices can begin at 1850 either one (as in Fortran) or zero. 
1851 1852 .seealso: [](ch_matrices), `Mat`, `MATAIJKOKOS`, `MATSEQAIJKOKOS`, `MATMPIAIJKOKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, 1853 `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATMPIAIJKOKKOS`, `MATAIJKOKKOS` 1854 @*/ 1855 PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) 1856 { 1857 PetscMPIInt size; 1858 1859 PetscFunctionBegin; 1860 PetscCall(MatCreate(comm, A)); 1861 PetscCall(MatSetSizes(*A, m, n, M, N)); 1862 PetscCallMPI(MPI_Comm_size(comm, &size)); 1863 if (size > 1) { 1864 PetscCall(MatSetType(*A, MATMPIAIJKOKKOS)); 1865 PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz)); 1866 } else { 1867 PetscCall(MatSetType(*A, MATSEQAIJKOKKOS)); 1868 PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz)); 1869 } 1870 PetscFunctionReturn(PETSC_SUCCESS); 1871 } 1872