#ifndef PETSCVECSEQCUPM_HPP
#define PETSCVECSEQCUPM_HPP

#include <petsc/private/veccupmimpl.h>

#if defined(__cplusplus)
  #include <petsc/private/randomimpl.h> // for _p_PetscRandom

  #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence

  #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
  #include "../src/sys/objects/device/impls/cupm/kernels.hpp"

  #if PetscDefined(USE_COMPLEX)
    #include <thrust/transform_reduce.h>
  #endif
  #include <thrust/transform.h>
  #include <thrust/reduce.h>
  #include <thrust/functional.h>
  #include <thrust/tuple.h>
  #include <thrust/device_ptr.h>
  #include <thrust/iterator/zip_iterator.h>
  #include <thrust/iterator/counting_iterator.h>
  #include <thrust/inner_product.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

// ==========================================================================================
// VecSeq_CUPM
// ==========================================================================================

// Vector operations for the sequential vector type on the device backend selected by T.
//
// The class is stateless: every member is a static function operating on the Vec handed to
// it. It derives (privately) from Vec_CUPMBase using the CRTP pattern, passing itself as the
// second template argument.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // implementation-specific customisation points; the definitions below forward to the host
  // VecSeq implementation (v->data cast, VECSEQ type name, VecDestroy_Seq, ...)
  PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bumps the object state of a vector whose local size is 0 but whose global size is not
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // the index_sequence overloads launch one kernel over sizeof...(Idx) vectors, the <int>
  // overloads launch a batch of that many vectors and advance the running vector index
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

// ==========================================================================================
// VecSeq_CUPM - Private API
// ==========================================================================================

// Returns v->data reinterpreted as the host-side Vec_Seq implementation struct.
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

// Type name of the device flavour of this implementation.
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

// Type name of the underlying host implementation.
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}

// Destroys the host-side portion of v by forwarding to VecDestroy_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}

// Forwards to VecResetArray_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v)
noexcept
{
  return VecResetArray_Seq(v);
}

// Forwards to VecPlaceArray_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

// Performs the host-side (VecSeq) part of vector creation.
//
// v             - the vector being created
// alloc_missing - optional output, always set to PETSC_FALSE here when non-null
// (unnamed)     - unused by this implementation
// host_array    - host array passed straight through to VecCreate_Seq_Private()
//
// Errors with PETSC_ERR_ARG_WRONG if the communicator of v has more than one rank.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// for functions with an early return based on vec size we still need to artificially bump the
// object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//      rank 0: [0],
//      rank 1: [<empty>]
//    }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  // only vectors that are locally empty (n == 0) but globally non-empty (N != 0) are bumped
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Common core of the create routines: runs the base class' private create step with
// host_array, then initializes the device-side data of v with the (possibly null) host and
// device arrays on the given device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Applies a binary functor entry-wise on the device: zout[i] = binary(xin[i], yin[i]).
//
// binary - callable forwarded to thrust::transform
// xin    - first operand, accessed read-only
// yin    - second operand, accessed read-only
// zout   - result, accessed write-only; its local size determines the number of entries
//
// Logs one flop per entry and synchronizes the device context afterwards. When zout has no
// local entries nothing is launched and only MaybeIncrementEmptyLocalVec() is called.
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin)
noexcept
{
  // Applies a unary functor entry-wise on the device. If yin is null or the same vector as
  // xinout the transform is done in place on xinout, otherwise xinout is read and the result
  // is written to yin.
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t stream;
    // note the lambda parameters deliberately shadow the Vec arguments, inside the lambda
    // they are raw device pointers. A null (or aliasing) output pointer means "in place".
    const auto apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // nothing to compute locally, but the vector that would have been written still needs
    // its object state bumped
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Constructors
// ==========================================================================================

// VecCreateSeqCUPM()
//
// Creates a vector of local size n on comm with block size bs. Since the vector is
// sequential n is passed as both the local and global size. call_set_type is forwarded to
// Create_CUPMBase().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCreateSeqCUPMWithArrays()
//
// Creates a vector of local size n that uses the supplied host and device arrays. Constness
// of both arrays is cast away before they are handed to CreateSeqCUPM_().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->duplicate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Utility
// ==========================================================================================

// v->ops->bindtocpu
//
// Delegates the actual (un)binding to BindToCPU_CUPMBase() and then re-points the function
// table entries of v. Each VecSetOp_CUPM() line names the op slot, the host (VecSeq)
// routine and the routine of this class for that slot.
// NOTE(review): VecSetOp_CUPM is defined outside this file, presumably it picks the host
// routine when usehost is true and the device one otherwise -- confirm.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot always uses the host routine, regardless of usehost
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Mutators
// ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread
//
// Makes w a local view of v. v must be of the sequential or MPI device vector type. Two
// paths exist:
//
// 1. v is petscnative and w is a sequential device vector: w's own host and device storage
//    is released and w takes over v's data, spptr, offload mask and pinned-memory flag.
// 2. otherwise: w's host array pointer is set from VecGetArray[Read]() on v and w is marked
//    as having its data on the host.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release whatever w currently owns, it is about to be replaced by v's storage
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // w->pinned_memory is cleared here, its previous value is handed to
        // UseCUPMHostAlloc for the duration of the free
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // w now shares v's implementation structs until RestoreLocalVector() is called
    PetscCall(PetscFree(w->data));
    w->data = v->data;
    w->offloadmask = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread
//
// Undoes GetLocalVector(): either hands the shared data, spptr, offload mask and
// pinned-memory flag back from w to v, or restores the host array obtained from v and
// frees w's device-side storage.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data = util::exchange(w->data, nullptr);
    v->spptr = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Compute Methods
// ==========================================================================================

// v->ops->aypx
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin)
noexcept 457 { 458 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 459 const auto sync = n != 0; 460 PetscDeviceContext dctx; 461 462 PetscFunctionBegin; 463 PetscCall(GetHandles_(&dctx)); 464 if (alpha == PetscScalar(0.0)) { 465 cupmStream_t stream; 466 467 PetscCall(GetHandlesFrom_(dctx, &stream)); 468 PetscCall(PetscLogGpuTimeBegin()); 469 PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream)); 470 PetscCall(PetscLogGpuTimeEnd()); 471 } else if (n) { 472 const auto alphaIsOne = alpha == PetscScalar(1.0); 473 const auto calpha = cupmScalarPtrCast(&alpha); 474 cupmBlasHandle_t cupmBlasHandle; 475 476 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle)); 477 { 478 const auto yptr = DeviceArrayReadWrite(dctx, yin); 479 const auto xptr = DeviceArrayRead(dctx, xin); 480 481 PetscCall(PetscLogGpuTimeBegin()); 482 if (alphaIsOne) { 483 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 484 } else { 485 const auto one = cupmScalarCast(1.0); 486 487 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1)); 488 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 489 } 490 PetscCall(PetscLogGpuTimeEnd()); 491 } 492 PetscCall(PetscLogGpuFlops((alphaIsOne ? 
1 : 2) * n)); 493 } 494 if (sync) PetscCall(PetscDeviceContextSynchronize(dctx)); 495 PetscFunctionReturn(PETSC_SUCCESS); 496 } 497 498 // v->ops->axpy 499 template <device::cupm::DeviceType T> 500 inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept 501 { 502 PetscBool xiscupm; 503 504 PetscFunctionBegin; 505 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 506 if (xiscupm) { 507 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 508 PetscDeviceContext dctx; 509 cupmBlasHandle_t cupmBlasHandle; 510 511 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 512 PetscCall(PetscLogGpuTimeBegin()); 513 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 514 PetscCall(PetscLogGpuTimeEnd()); 515 PetscCall(PetscLogGpuFlops(2 * n)); 516 PetscCall(PetscDeviceContextSynchronize(dctx)); 517 } else { 518 PetscCall(VecAXPY_Seq(yin, alpha, xin)); 519 } 520 PetscFunctionReturn(PETSC_SUCCESS); 521 } 522 523 // v->ops->pointwisedivide 524 template <device::cupm::DeviceType T> 525 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept 526 { 527 PetscFunctionBegin; 528 if (xin->boundtocpu || yin->boundtocpu) { 529 PetscCall(VecPointwiseDivide_Seq(win, xin, yin)); 530 } else { 531 // note order of arguments! xin and yin are read, win is written! 532 PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win)); 533 } 534 PetscFunctionReturn(PETSC_SUCCESS); 535 } 536 537 // v->ops->pointwisemult 538 template <device::cupm::DeviceType T> 539 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept 540 { 541 PetscFunctionBegin; 542 if (xin->boundtocpu || yin->boundtocpu) { 543 PetscCall(VecPointwiseMult_Seq(win, xin, yin)); 544 } else { 545 // note order of arguments! xin and yin are read, win is written! 
546 PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win)); 547 } 548 PetscFunctionReturn(PETSC_SUCCESS); 549 } 550 551 namespace detail 552 { 553 554 struct reciprocal { 555 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept 556 { 557 // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex 558 // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap 559 // everything in PetscScalar... 560 return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s; 561 } 562 }; 563 564 } // namespace detail 565 566 // v->ops->reciprocal 567 template <device::cupm::DeviceType T> 568 inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept 569 { 570 PetscFunctionBegin; 571 PetscCall(PointwiseUnary_(detail::reciprocal{}, xin)); 572 PetscFunctionReturn(PETSC_SUCCESS); 573 } 574 575 // v->ops->waxpy 576 template <device::cupm::DeviceType T> 577 inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept 578 { 579 PetscFunctionBegin; 580 if (alpha == PetscScalar(0.0)) { 581 PetscCall(Copy(yin, win)); 582 } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) { 583 PetscDeviceContext dctx; 584 cupmBlasHandle_t cupmBlasHandle; 585 cupmStream_t stream; 586 587 PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream)); 588 { 589 const auto wptr = DeviceArrayWrite(dctx, win); 590 591 PetscCall(PetscLogGpuTimeBegin()); 592 PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true)); 593 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1)); 594 PetscCall(PetscLogGpuTimeEnd()); 595 } 596 PetscCall(PetscLogGpuFlops(2 * n)); 597 PetscCall(PetscDeviceContextSynchronize(dctx)); 598 } 599 PetscFunctionReturn(PETSC_SUCCESS); 600 } 601 602 namespace kernels 603 { 604 605 template 
<typename... Args>
// Kernel computing xptr[i] += sum_j aptr[j] * yptr_j[i] for i in [0, size), where the
// number of y vectors N is the number of trailing pointer arguments.
//
// size - number of entries per vector
// xptr - vector updated in place
// aptr - the N coefficients, staged in shared memory by the first N threads of each block
// yptr - the N input vectors
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int N = sizeof...(Args);
  const auto tx = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
//   typename repeat_type<MyType, IdxParamPack>...
// expands to
//   MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};

} // namespace detail

// Launches MAXPY_kernel on stream for the sizeof...(Idx) vectors yin[0], yin[1], ... with
// coefficients aptr[0], aptr[1], ... Each yin[Idx] is accessed read-only on the device.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Launches one batch of N vectors starting at index yidx of aptr and yin, then advances
// yidx by N.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->maxpy
//
// Computes x += sum_j alpha[j] * yin[j] for j in [0, nv).
//
// xin   - vector updated in place
// nv    - number of y vectors
// alpha - nv coefficients in host memory, copied to a temporary device buffer
// yin   - the nv input vectors
//
// The vectors are processed in batches of at most 8 per kernel launch.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto xptr = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt yidx = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // dispatch on the number of vectors left, every case advances yidx by its batch size
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Computes the dot product of xin and yin via the BLAS dot routine and stores it in *z.
// If xin has no local entries *z is set to 0 without touching the device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// number of blocks and number of threads per block used for the MDot kernels, both 128
  #define MDOT_WORKGROUP_NUM  128
  #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

namespace kernels
{

// Number of vector entries assigned to each block: (size - 1) / gridDim.x + 1, falling back
// to 1 should that evaluate to 0.
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

template <typename...
ConstPetscScalarPointer>
// Kernel computing per-block partial sums of x[i] * y_j[i] for each of the N vectors y_j
// given as trailing arguments.
//
// x       - the common vector
// size    - number of entries per vector
// results - output, the partial sum of block bx for vector j is written to
//           results[bx + j * gridDim.x]
// y       - the N other vectors
//
// NOTE(review): the shared memory buffer has MDOT_WORKGROUP_SIZE slots per vector and is
// indexed by threadIdx.x, so this presumably requires blockDim.x <= MDOT_WORKGROUP_SIZE --
// confirm against the launch configuration.
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int N = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx = threadIdx.x;
  const auto bx = blockIdx.x;
  const auto bdx = blockDim.x;
  const auto gdx = gridDim.x;
  // each block handles the chunk [bx * worksize, min((bx + 1) * worksize, size)), its
  // threads walk that chunk with a stride of blockDim.x
  const auto worksize = EntriesPerGroup(size);
  const auto begin = tx + bx * worksize;
  const auto end = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // publish the per-thread sums, vector i occupies shmem[i * MDOT_WORKGROUP_SIZE + ...]
  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

namespace
{

// Kernel collapsing the per-block partial sums produced by MDot_kernel: for every index i in
// [0, size) the MDOT_WORKGROUP_SIZE consecutive entries starting at
// results[i * MDOT_WORKGROUP_SIZE] are summed and the total is stored in results[i].
//
// NOTE(review): the per-thread buffer holds 8 sums, so presumably each thread handles at
// most 8 indices -- confirm against the launch configuration.
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // |  ______________________________________________________/
  // | /            <- MDOT_WORKGROUP_NUM ->
  // |/
  // +
  // v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // the buffer is drained from its most recently written slot downwards
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // namespace

} // namespace kernels

template <device::cupm::DeviceType T>
template <std::size_t...
Idx> 867 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept 868 { 869 PetscFunctionBegin; 870 // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches 871 // 128 blocks of 128 threads every time which may be wasteful 872 // clang-format off 873 PetscCallCUPM( 874 cupmLaunchKernel( 875 kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 876 MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream, 877 xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()... 878 ) 879 ); 880 // clang-format on 881 PetscFunctionReturn(PETSC_SUCCESS); 882 } 883 884 template <device::cupm::DeviceType T> 885 template <int N> 886 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept 887 { 888 PetscFunctionBegin; 889 PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{})); 890 yidx += N; 891 PetscFunctionReturn(PETSC_SUCCESS); 892 } 893 894 template <device::cupm::DeviceType T> 895 inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept 896 { 897 // the largest possible size of a batch 898 constexpr PetscInt batchsize = 8; 899 // how many sub streams to create, if nv <= batchsize we can do this without looping, so we 900 // do not create substreams. Note we don't create more than 8 streams, in practice we could 901 // not get more parallelism with higher numbers. 902 const auto num_sub_streams = nv > batchsize ? 
std::min((nv + batchsize) / batchsize, batchsize) : 0; 903 const auto n = xin->map->n; 904 // number of vectors that we handle via the batches. note any singletons are handled by 905 // cublas, hence the nv-1. 906 const auto nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv; 907 const auto nwork = nvbatch * MDOT_WORKGROUP_NUM; 908 PetscScalar *d_results; 909 cupmStream_t stream; 910 911 PetscFunctionBegin; 912 PetscCall(GetHandlesFrom_(dctx, &stream)); 913 // allocate scratchpad memory for the results of individual work groups 914 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results)); 915 { 916 const auto xptr = DeviceArrayRead(dctx, xin); 917 PetscInt yidx = 0; 918 auto subidx = 0; 919 auto cur_stream = stream; 920 auto cur_ctx = dctx; 921 PetscDeviceContext *sub = nullptr; 922 PetscStreamType stype; 923 924 // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of 925 // sub. Ideally the parent context should also join in on the fork, but it is extremely 926 // fiddly to do so presently 927 PetscCall(PetscDeviceContextGetStreamType(dctx, &stype)); 928 if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING; 929 // If we have a globally blocking stream create nonblocking streams instead (as we can 930 // locally exploit the parallelism). Otherwise use the prescribed stream type. 931 PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub)); 932 PetscCall(PetscLogGpuTimeBegin()); 933 do { 934 if (num_sub_streams) { 935 cur_ctx = sub[subidx++ % num_sub_streams]; 936 PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream)); 937 } 938 // REVIEW ME: Should probably try and load-balance these. 
Consider the case where nv = 9; 939 // it is very likely better to do 4+5 rather than 8+1 940 switch (nv - yidx) { 941 case 7: 942 PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 943 break; 944 case 6: 945 PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 946 break; 947 case 5: 948 PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 949 break; 950 case 4: 951 PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 952 break; 953 case 3: 954 PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 955 break; 956 case 2: 957 PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 958 break; 959 case 1: { 960 cupmBlasHandle_t cupmBlasHandle; 961 962 PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle)); 963 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx))); 964 ++yidx; 965 } break; 966 default: // 8 or more 967 PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 968 break; 969 } 970 } while (yidx < nv); 971 PetscCall(PetscLogGpuTimeEnd()); 972 PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub)); 973 } 974 975 PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results)); 976 // copy result of device reduction to host 977 PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream)); 978 // do these now while final reduction is in flight 979 PetscCall(PetscLogFlops(nwork)); 980 PetscCall(PetscDeviceFree(dctx, d_results)); 981 PetscFunctionReturn(PETSC_SUCCESS); 982 } 983 984 #undef MDOT_WORKGROUP_NUM 985 #undef MDOT_WORKGROUP_SIZE 986 987 
template <device::cupm::DeviceType T> 988 inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept 989 { 990 // probably not worth it to run more than 8 of these at a time? 991 const auto n_sub = PetscMin(nv, 8); 992 const auto n = static_cast<cupmBlasInt_t>(xin->map->n); 993 const auto xptr = DeviceArrayRead(dctx, xin); 994 PetscScalar *d_z; 995 PetscDeviceContext *subctx; 996 cupmStream_t stream; 997 998 PetscFunctionBegin; 999 PetscCall(GetHandlesFrom_(dctx, &stream)); 1000 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z)); 1001 PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx)); 1002 PetscCall(PetscLogGpuTimeBegin()); 1003 for (PetscInt i = 0; i < nv; ++i) { 1004 const auto sub = subctx[i % n_sub]; 1005 cupmBlasHandle_t handle; 1006 cupmBlasPointerMode_t old_mode; 1007 1008 PetscCall(GetHandlesFrom_(sub, &handle)); 1009 PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode)); 1010 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE)); 1011 PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i))); 1012 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode)); 1013 } 1014 PetscCall(PetscLogGpuTimeEnd()); 1015 PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx)); 1016 PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream)); 1017 PetscCall(PetscDeviceFree(dctx, d_z)); 1018 // REVIEW ME: flops????? 
1019 PetscFunctionReturn(PETSC_SUCCESS); 1020 } 1021 1022 // v->ops->mdot 1023 template <device::cupm::DeviceType T> 1024 inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept 1025 { 1026 PetscFunctionBegin; 1027 if (PetscUnlikely(nv == 1)) { 1028 // dot handles nv = 0 correctly 1029 PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z)); 1030 } else if (const auto n = xin->map->n) { 1031 PetscDeviceContext dctx; 1032 1033 PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv); 1034 PetscCall(GetHandles_(&dctx)); 1035 PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx)); 1036 // REVIEW ME: double count of flops?? 1037 PetscCall(PetscLogGpuFlops(nv * (2 * n - 1))); 1038 PetscCall(PetscDeviceContextSynchronize(dctx)); 1039 } else { 1040 PetscCall(PetscArrayzero(z, nv)); 1041 } 1042 PetscFunctionReturn(PETSC_SUCCESS); 1043 } 1044 1045 // v->ops->set 1046 template <device::cupm::DeviceType T> 1047 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept 1048 { 1049 const auto n = xin->map->n; 1050 PetscDeviceContext dctx; 1051 cupmStream_t stream; 1052 1053 PetscFunctionBegin; 1054 PetscCall(GetHandles_(&dctx, &stream)); 1055 { 1056 const auto xptr = DeviceArrayWrite(dctx, xin); 1057 1058 if (alpha == PetscScalar(0.0)) { 1059 PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream)); 1060 } else { 1061 const auto dptr = thrust::device_pointer_cast(xptr.data()); 1062 1063 PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha)); 1064 } 1065 if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing 1066 } 1067 PetscFunctionReturn(PETSC_SUCCESS); 1068 } 1069 1070 // v->ops->scale 1071 template <device::cupm::DeviceType T> 1072 inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept 1073 { 1074 
PetscFunctionBegin; 1075 if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS); 1076 if (PetscUnlikely(alpha == PetscScalar(0.0))) { 1077 PetscCall(Set(xin, alpha)); 1078 } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1079 PetscDeviceContext dctx; 1080 cupmBlasHandle_t cupmBlasHandle; 1081 1082 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1083 PetscCall(PetscLogGpuTimeBegin()); 1084 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1)); 1085 PetscCall(PetscLogGpuTimeEnd()); 1086 PetscCall(PetscLogGpuFlops(n)); 1087 PetscCall(PetscDeviceContextSynchronize(dctx)); 1088 } else { 1089 PetscCall(MaybeIncrementEmptyLocalVec(xin)); 1090 } 1091 PetscFunctionReturn(PETSC_SUCCESS); 1092 } 1093 1094 // v->ops->tdot 1095 template <device::cupm::DeviceType T> 1096 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept 1097 { 1098 PetscFunctionBegin; 1099 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1100 PetscDeviceContext dctx; 1101 cupmBlasHandle_t cupmBlasHandle; 1102 1103 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1104 PetscCall(PetscLogGpuTimeBegin()); 1105 PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z))); 1106 PetscCall(PetscLogGpuTimeEnd()); 1107 PetscCall(PetscLogGpuFlops(2 * n - 1)); 1108 } else { 1109 *z = 0.0; 1110 } 1111 PetscFunctionReturn(PETSC_SUCCESS); 1112 } 1113 1114 // v->ops->copy 1115 template <device::cupm::DeviceType T> 1116 inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept 1117 { 1118 PetscFunctionBegin; 1119 if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS); 1120 if (const auto n = xin->map->n) { 1121 const auto xmask = xin->offloadmask; 1122 // silence buggy gcc warning: mode may be used uninitialized in this function 1123 auto mode = cupmMemcpyDeviceToDevice; 1124 
PetscDeviceContext dctx; 1125 cupmStream_t stream; 1126 1127 // translate from PetscOffloadMask to cupmMemcpyKind 1128 switch (const auto ymask = yout->offloadmask) { 1129 case PETSC_OFFLOAD_UNALLOCATED: { 1130 PetscBool yiscupm; 1131 1132 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 1133 if (yiscupm) { 1134 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost; 1135 break; 1136 } 1137 } // fall-through if unallocated and not cupm 1138 #if PETSC_CPP_VERSION >= 17 1139 [[fallthrough]]; 1140 #endif 1141 case PETSC_OFFLOAD_CPU: 1142 mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost; 1143 break; 1144 case PETSC_OFFLOAD_BOTH: 1145 case PETSC_OFFLOAD_GPU: 1146 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 1147 break; 1148 default: 1149 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask)); 1150 } 1151 1152 PetscCall(GetHandles_(&dctx, &stream)); 1153 switch (mode) { 1154 case cupmMemcpyDeviceToDevice: // the best case 1155 case cupmMemcpyHostToDevice: { // not terrible 1156 const auto yptr = DeviceArrayWrite(dctx, yout); 1157 const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data(); 1158 1159 PetscCall(PetscLogGpuTimeBegin()); 1160 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream)); 1161 PetscCall(PetscLogGpuTimeEnd()); 1162 } break; 1163 case cupmMemcpyDeviceToHost: // not great 1164 case cupmMemcpyHostToHost: { // worst case 1165 const auto xptr = mode == cupmMemcpyDeviceToHost ? 
DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data(); 1166 PetscScalar *yptr; 1167 1168 PetscCall(VecGetArrayWrite(yout, &yptr)); 1169 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin()); 1170 PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true)); 1171 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd()); 1172 PetscCall(VecRestoreArrayWrite(yout, &yptr)); 1173 } break; 1174 default: 1175 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode)); 1176 } 1177 PetscCall(PetscDeviceContextSynchronize(dctx)); 1178 } else { 1179 PetscCall(MaybeIncrementEmptyLocalVec(yout)); 1180 } 1181 PetscFunctionReturn(PETSC_SUCCESS); 1182 } 1183 1184 // v->ops->swap 1185 template <device::cupm::DeviceType T> 1186 inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept 1187 { 1188 PetscFunctionBegin; 1189 if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS); 1190 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1191 PetscDeviceContext dctx; 1192 cupmBlasHandle_t cupmBlasHandle; 1193 1194 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1195 PetscCall(PetscLogGpuTimeBegin()); 1196 PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 1197 PetscCall(PetscLogGpuTimeEnd()); 1198 PetscCall(PetscDeviceContextSynchronize(dctx)); 1199 } else { 1200 PetscCall(MaybeIncrementEmptyLocalVec(xin)); 1201 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 1202 } 1203 PetscFunctionReturn(PETSC_SUCCESS); 1204 } 1205 1206 // v->ops->axpby 1207 template <device::cupm::DeviceType T> 1208 inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept 1209 { 1210 PetscFunctionBegin; 1211 if (alpha == PetscScalar(0.0)) { 1212 PetscCall(Scale(yin, beta)); 1213 } else if (beta == PetscScalar(1.0)) { 1214 PetscCall(AXPY(yin, alpha, xin)); 1215 } else if (alpha == 
PetscScalar(1.0)) { 1216 PetscCall(AYPX(yin, beta, xin)); 1217 } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) { 1218 const auto betaIsZero = beta == PetscScalar(0.0); 1219 const auto aptr = cupmScalarPtrCast(&alpha); 1220 PetscDeviceContext dctx; 1221 cupmBlasHandle_t cupmBlasHandle; 1222 1223 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1224 { 1225 const auto xptr = DeviceArrayRead(dctx, xin); 1226 1227 if (betaIsZero /* beta = 0 */) { 1228 // here we can get away with purely write-only as we memcpy into it first 1229 const auto yptr = DeviceArrayWrite(dctx, yin); 1230 cupmStream_t stream; 1231 1232 PetscCall(GetHandlesFrom_(dctx, &stream)); 1233 PetscCall(PetscLogGpuTimeBegin()); 1234 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream)); 1235 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1)); 1236 } else { 1237 const auto yptr = DeviceArrayReadWrite(dctx, yin); 1238 1239 PetscCall(PetscLogGpuTimeBegin()); 1240 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1)); 1241 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 1242 } 1243 } 1244 PetscCall(PetscLogGpuTimeEnd()); 1245 PetscCall(PetscLogGpuFlops((betaIsZero ? 
1 : 3) * n)); 1246 PetscCall(PetscDeviceContextSynchronize(dctx)); 1247 } else { 1248 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 1249 } 1250 PetscFunctionReturn(PETSC_SUCCESS); 1251 } 1252 1253 // v->ops->axpbypcz 1254 template <device::cupm::DeviceType T> 1255 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept 1256 { 1257 PetscFunctionBegin; 1258 if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma)); 1259 PetscCall(AXPY(zin, alpha, xin)); 1260 PetscCall(AXPY(zin, beta, yin)); 1261 PetscFunctionReturn(PETSC_SUCCESS); 1262 } 1263 1264 // v->ops->norm 1265 template <device::cupm::DeviceType T> 1266 inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept 1267 { 1268 PetscDeviceContext dctx; 1269 cupmBlasHandle_t cupmBlasHandle; 1270 1271 PetscFunctionBegin; 1272 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1273 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1274 const auto xptr = DeviceArrayRead(dctx, xin); 1275 PetscInt flopCount = 0; 1276 1277 PetscCall(PetscLogGpuTimeBegin()); 1278 switch (type) { 1279 case NORM_1_AND_2: 1280 case NORM_1: 1281 PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z))); 1282 flopCount = std::max(n - 1, 0); 1283 if (type == NORM_1) break; 1284 ++z; // fall-through 1285 #if PETSC_CPP_VERSION >= 17 1286 [[fallthrough]]; 1287 #endif 1288 case NORM_2: 1289 case NORM_FROBENIUS: 1290 PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z))); 1291 flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2 1292 break; 1293 case NORM_INFINITY: { 1294 cupmBlasInt_t max_loc = 0; 1295 PetscScalar xv = 0.; 1296 cupmStream_t stream; 1297 1298 PetscCall(GetHandlesFrom_(dctx, &stream)); 1299 PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc)); 1300 
PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream)); 1301 *z = PetscAbsScalar(xv); 1302 // REVIEW ME: flopCount = ??? 1303 } break; 1304 } 1305 PetscCall(PetscLogGpuTimeEnd()); 1306 PetscCall(PetscLogGpuFlops(flopCount)); 1307 } else { 1308 z[0] = 0.0; 1309 z[type == NORM_1_AND_2] = 0.0; 1310 } 1311 PetscFunctionReturn(PETSC_SUCCESS); 1312 } 1313 1314 namespace detail 1315 { 1316 1317 struct dotnorm2_mult { 1318 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept 1319 { 1320 const auto conjt = PetscConj(t); 1321 1322 return {s * conjt, t * conjt}; 1323 } 1324 }; 1325 1326 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I 1327 // would do it myself but now I am worried that they do so on purpose... 1328 struct dotnorm2_tuple_plus { 1329 using value_type = thrust::tuple<PetscScalar, PetscScalar>; 1330 1331 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; } 1332 }; 1333 1334 } // namespace detail 1335 1336 // v->ops->dotnorm2 1337 template <device::cupm::DeviceType T> 1338 inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept 1339 { 1340 PetscDeviceContext dctx; 1341 cupmStream_t stream; 1342 1343 PetscFunctionBegin; 1344 PetscCall(GetHandles_(&dctx, &stream)); 1345 { 1346 PetscScalar dpt = 0.0, nmt = 0.0; 1347 const auto sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data()); 1348 1349 // clang-format off 1350 PetscCallThrust( 1351 thrust::tie(*dp, *nm) = THRUST_CALL( 1352 thrust::inner_product, 1353 stream, 1354 sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()), 1355 thrust::make_tuple(dpt, nmt), 1356 detail::dotnorm2_tuple_plus{}, 
detail::dotnorm2_mult{} 1357 ); 1358 ); 1359 // clang-format on 1360 } 1361 PetscFunctionReturn(PETSC_SUCCESS); 1362 } 1363 1364 namespace detail 1365 { 1366 1367 struct conjugate { 1368 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); } 1369 }; 1370 1371 } // namespace detail 1372 1373 // v->ops->conjugate 1374 template <device::cupm::DeviceType T> 1375 inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept 1376 { 1377 PetscFunctionBegin; 1378 if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin)); 1379 PetscFunctionReturn(PETSC_SUCCESS); 1380 } 1381 1382 namespace detail 1383 { 1384 1385 struct real_part { 1386 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; } 1387 1388 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); } 1389 }; 1390 1391 // deriving from Operator allows us to "store" an instance of the operator in the class but 1392 // also take advantage of empty base class optimization if the operator is stateless 1393 template <typename Operator> 1394 class tuple_compare : Operator { 1395 public: 1396 using tuple_type = thrust::tuple<PetscReal, PetscInt>; 1397 using operator_type = Operator; 1398 1399 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept 1400 { 1401 if (op_()(y.get<0>(), x.get<0>())) { 1402 // if y is strictly greater/less than x, return y 1403 return y; 1404 } else if (y.get<0>() == x.get<0>()) { 1405 // if equal, prefer lower index 1406 return y.get<1>() < x.get<1>() ? 
y : x; 1407 } 1408 // otherwise return x 1409 return x; 1410 } 1411 1412 private: 1413 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; } 1414 }; 1415 1416 } // namespace detail 1417 1418 template <device::cupm::DeviceType T> 1419 template <typename TupleFuncT, typename UnaryFuncT> 1420 inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept 1421 { 1422 PetscFunctionBegin; 1423 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 1424 if (p) *p = -1; 1425 if (const auto n = v->map->n) { 1426 PetscDeviceContext dctx; 1427 cupmStream_t stream; 1428 1429 PetscCall(GetHandles_(&dctx, &stream)); 1430 // needed to: 1431 // 1. switch between transform_reduce and reduce 1432 // 2. strip the real_part functor from the arguments 1433 #if PetscDefined(USE_COMPLEX) 1434 #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__) 1435 #else 1436 #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__) 1437 #endif 1438 { 1439 const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1440 1441 if (p) { 1442 // clang-format off 1443 const auto zip = thrust::make_zip_iterator( 1444 thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0})) 1445 ); 1446 // clang-format on 1447 // need to use preprocessor conditionals since otherwise thrust complains about not being 1448 // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex 1449 // builds... 
1450 // clang-format off 1451 PetscCallThrust( 1452 thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE( 1453 stream, zip, zip + n, detail::real_part{}, 1454 thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr) 1455 ); 1456 ); 1457 // clang-format on 1458 } else { 1459 // clang-format off 1460 PetscCallThrust( 1461 *m = THRUST_MINMAX_REDUCE( 1462 stream, vptr, vptr + n, detail::real_part{}, 1463 *m, std::forward<UnaryFuncT>(unary_ftr) 1464 ); 1465 ); 1466 // clang-format on 1467 } 1468 } 1469 #undef THRUST_MINMAX_REDUCE 1470 } 1471 // REVIEW ME: flops? 1472 PetscFunctionReturn(PETSC_SUCCESS); 1473 } 1474 1475 // v->ops->max 1476 template <device::cupm::DeviceType T> 1477 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept 1478 { 1479 using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>; 1480 using unary_functor = thrust::maximum<PetscReal>; 1481 1482 PetscFunctionBegin; 1483 *m = PETSC_MIN_REAL; 1484 // use {} constructor syntax otherwise most vexing parse 1485 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1486 PetscFunctionReturn(PETSC_SUCCESS); 1487 } 1488 1489 // v->ops->min 1490 template <device::cupm::DeviceType T> 1491 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept 1492 { 1493 using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>; 1494 using unary_functor = thrust::minimum<PetscReal>; 1495 1496 PetscFunctionBegin; 1497 *m = PETSC_MAX_REAL; 1498 // use {} constructor syntax otherwise most vexing parse 1499 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1500 PetscFunctionReturn(PETSC_SUCCESS); 1501 } 1502 1503 // v->ops->sum 1504 template <device::cupm::DeviceType T> 1505 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept 1506 { 1507 PetscFunctionBegin; 1508 if (const auto n = v->map->n) { 1509 PetscDeviceContext dctx; 1510 cupmStream_t stream; 1511 1512 PetscCall(GetHandles_(&dctx, &stream)); 1513 
const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1514 // REVIEW ME: why not cupmBlasXasum()? 1515 PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0});); 1516 // REVIEW ME: must be at least n additions 1517 PetscCall(PetscLogGpuFlops(n)); 1518 } else { 1519 *sum = 0.0; 1520 } 1521 PetscFunctionReturn(PETSC_SUCCESS); 1522 } 1523 1524 template <device::cupm::DeviceType T> 1525 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept 1526 { 1527 PetscFunctionBegin; 1528 PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v)); 1529 PetscFunctionReturn(PETSC_SUCCESS); 1530 } 1531 1532 template <device::cupm::DeviceType T> 1533 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept 1534 { 1535 PetscFunctionBegin; 1536 if (const auto n = v->map->n) { 1537 PetscBool iscurand; 1538 PetscDeviceContext dctx; 1539 1540 PetscCall(GetHandles_(&dctx)); 1541 PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand)); 1542 if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v))); 1543 else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v))); 1544 } else { 1545 PetscCall(MaybeIncrementEmptyLocalVec(v)); 1546 } 1547 // REVIEW ME: flops???? 1548 // REVIEW ME: Timing??? 
1549 PetscFunctionReturn(PETSC_SUCCESS); 1550 } 1551 1552 // v->ops->setpreallocation 1553 template <device::cupm::DeviceType T> 1554 inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept 1555 { 1556 PetscDeviceContext dctx; 1557 1558 PetscFunctionBegin; 1559 PetscCall(GetHandles_(&dctx)); 1560 PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i)); 1561 PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx)); 1562 PetscFunctionReturn(PETSC_SUCCESS); 1563 } 1564 1565 namespace kernels 1566 { 1567 1568 template <typename F> 1569 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 1570 { 1571 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 1572 const auto end = jmap[i + 1]; 1573 const auto idx = xvindex(i); 1574 PetscScalar sum = 0.0; 1575 1576 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 1577 1578 if (imode == INSERT_VALUES) { 1579 xv[idx] = sum; 1580 } else { 1581 xv[idx] += sum; 1582 } 1583 }); 1584 return; 1585 } 1586 1587 namespace 1588 { 1589 1590 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 1591 { 1592 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 1593 return; 1594 } 1595 1596 } // namespace 1597 1598 #if PetscDefined(USING_HCC) 1599 namespace do_not_use 1600 { 1601 1602 // Needed to silence clang warning: 1603 // 1604 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 1605 // 1606 // The warning is silly, since the function *is* used, however the host compiler does not 1607 // appear see this. 
Likely because the function using it is in a template. 1608 // 1609 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 1610 inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted() 1611 { 1612 (void)sum_kernel; 1613 } 1614 1615 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 1616 { 1617 (void)add_coo_values; 1618 } 1619 1620 } // namespace do_not_use 1621 #endif 1622 1623 } // namespace kernels 1624 1625 // v->ops->setvaluescoo 1626 template <device::cupm::DeviceType T> 1627 inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept 1628 { 1629 auto vv = const_cast<PetscScalar *>(v); 1630 PetscMemType memtype; 1631 PetscDeviceContext dctx; 1632 cupmStream_t stream; 1633 1634 PetscFunctionBegin; 1635 PetscCall(GetHandles_(&dctx, &stream)); 1636 PetscCall(PetscGetMemType(v, &memtype)); 1637 if (PetscMemTypeHost(memtype)) { 1638 const auto size = VecIMPLCast(x)->coo_n; 1639 1640 // If user gave v[] in host, we might need to copy it to device if any 1641 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv)); 1642 PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream)); 1643 } 1644 1645 if (const auto n = x->map->n) { 1646 const auto vcu = VecCUPMCast(x); 1647 1648 PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? 
DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data())); 1649 } else { 1650 PetscCall(MaybeIncrementEmptyLocalVec(x)); 1651 } 1652 1653 if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv)); 1654 PetscCall(PetscDeviceContextSynchronize(dctx)); 1655 PetscFunctionReturn(PETSC_SUCCESS); 1656 } 1657 1658 } // namespace impl 1659 1660 // ========================================================================================== 1661 // VecSeq_CUPM - Implementations 1662 // ========================================================================================== 1663 1664 namespace 1665 { 1666 1667 template <device::cupm::DeviceType T> 1668 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 1669 { 1670 PetscFunctionBegin; 1671 PetscValidPointer(v, 4); 1672 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 1673 PetscFunctionReturn(PETSC_SUCCESS); 1674 } 1675 1676 template <device::cupm::DeviceType T> 1677 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 1678 { 1679 PetscFunctionBegin; 1680 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4); 1681 PetscValidPointer(v, 6); 1682 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 1683 PetscFunctionReturn(PETSC_SUCCESS); 1684 } 1685 1686 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1687 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1688 { 1689 PetscFunctionBegin; 1690 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1691 PetscValidPointer(a, 2); 1692 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1693 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1694 PetscFunctionReturn(PETSC_SUCCESS); 1695 } 1696 1697 template 
<PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1698 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1699 { 1700 PetscFunctionBegin; 1701 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1702 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1703 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1704 PetscFunctionReturn(PETSC_SUCCESS); 1705 } 1706 1707 template <device::cupm::DeviceType T> 1708 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1709 { 1710 PetscFunctionBegin; 1711 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 1712 PetscFunctionReturn(PETSC_SUCCESS); 1713 } 1714 1715 template <device::cupm::DeviceType T> 1716 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1717 { 1718 PetscFunctionBegin; 1719 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 1720 PetscFunctionReturn(PETSC_SUCCESS); 1721 } 1722 1723 template <device::cupm::DeviceType T> 1724 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1725 { 1726 PetscFunctionBegin; 1727 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 1728 PetscFunctionReturn(PETSC_SUCCESS); 1729 } 1730 1731 template <device::cupm::DeviceType T> 1732 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1733 { 1734 PetscFunctionBegin; 1735 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 1736 PetscFunctionReturn(PETSC_SUCCESS); 1737 } 1738 1739 template <device::cupm::DeviceType T> 1740 inline PetscErrorCode 
VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1741 { 1742 PetscFunctionBegin; 1743 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 1744 PetscFunctionReturn(PETSC_SUCCESS); 1745 } 1746 1747 template <device::cupm::DeviceType T> 1748 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1749 { 1750 PetscFunctionBegin; 1751 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 1752 PetscFunctionReturn(PETSC_SUCCESS); 1753 } 1754 1755 template <device::cupm::DeviceType T> 1756 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 1757 { 1758 PetscFunctionBegin; 1759 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 1760 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1761 PetscFunctionReturn(PETSC_SUCCESS); 1762 } 1763 1764 template <device::cupm::DeviceType T> 1765 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 1766 { 1767 PetscFunctionBegin; 1768 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 1769 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1770 PetscFunctionReturn(PETSC_SUCCESS); 1771 } 1772 1773 template <device::cupm::DeviceType T> 1774 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 1775 { 1776 PetscFunctionBegin; 1777 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 1778 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 1779 PetscFunctionReturn(PETSC_SUCCESS); 1780 } 1781 1782 } // anonymous namespace 1783 1784 } // namespace cupm 1785 1786 } // namespace vec 1787 1788 } // namespace Petsc 1789 1790 #endif // __cplusplus 1791 1792 #endif // PETSCVECSEQCUPM_HPP 1793