1 #ifndef PETSCVECSEQCUPM_HPP 2 #define PETSCVECSEQCUPM_HPP 3 4 #include <petsc/private/veccupmimpl.h> 5 6 #if defined(__cplusplus) 7 #include <petsc/private/randomimpl.h> // for _p_PetscRandom 8 9 #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence 10 11 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp" 12 #include "../src/sys/objects/device/impls/cupm/kernels.hpp" 13 14 #if PetscDefined(USE_COMPLEX) 15 #include <thrust/transform_reduce.h> 16 #endif 17 #include <thrust/transform.h> 18 #include <thrust/reduce.h> 19 #include <thrust/functional.h> 20 #include <thrust/tuple.h> 21 #include <thrust/device_ptr.h> 22 #include <thrust/iterator/zip_iterator.h> 23 #include <thrust/iterator/counting_iterator.h> 24 #include <thrust/inner_product.h> 25 26 namespace Petsc 27 { 28 29 namespace vec 30 { 31 32 namespace cupm 33 { 34 35 namespace impl 36 { 37 38 // ========================================================================================== 39 // VecSeq_CUPM 40 // ========================================================================================== 41 42 template <device::cupm::DeviceType T> 43 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 44 public: 45 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 46 47 private: 48 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 49 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 50 51 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 52 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 53 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 54 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 55 56 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 57 58 // common core for min and max 59 template <typename TupleFuncT, typename UnaryFuncT> 60 static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) 
noexcept; 61 // common core for pointwise binary and pointwise unary thrust functions 62 template <typename BinaryFuncT> 63 static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept; 64 template <typename UnaryFuncT> 65 static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept; 66 // mdot dispatchers 67 static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 68 static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 69 template <std::size_t... Idx> 70 static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 71 template <int> 72 static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 73 template <std::size_t... 
Idx> 74 static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 75 template <int> 76 static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 77 // common core for the various create routines 78 static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 79 80 public: 81 // callable directly via a bespoke function 82 static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 83 static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 84 85 // callable indirectly via function pointers 86 static PetscErrorCode duplicate(Vec, Vec *) noexcept; 87 static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept; 88 static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept; 89 static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept; 90 static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept; 91 static PetscErrorCode reciprocal(Vec) noexcept; 92 static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept; 93 static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 94 static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept; 95 static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 96 static PetscErrorCode set(Vec, PetscScalar) noexcept; 97 static PetscErrorCode scale(Vec, PetscScalar) noexcept; 98 static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept; 99 static PetscErrorCode copy(Vec, Vec) noexcept; 100 static PetscErrorCode swap(Vec, Vec) noexcept; 101 static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept; 102 static PetscErrorCode axpbypcz(Vec, PetscScalar, 
PetscScalar, PetscScalar, Vec, Vec) noexcept; 103 static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept; 104 static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 105 static PetscErrorCode destroy(Vec) noexcept; 106 static PetscErrorCode conjugate(Vec) noexcept; 107 template <PetscMemoryAccessMode> 108 static PetscErrorCode getlocalvector(Vec, Vec) noexcept; 109 template <PetscMemoryAccessMode> 110 static PetscErrorCode restorelocalvector(Vec, Vec) noexcept; 111 static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept; 112 static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept; 113 static PetscErrorCode sum(Vec, PetscScalar *) noexcept; 114 static PetscErrorCode shift(Vec, PetscScalar) noexcept; 115 static PetscErrorCode setrandom(Vec, PetscRandom) noexcept; 116 static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept; 117 static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept; 118 static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept; 119 }; 120 121 // ========================================================================================== 122 // VecSeq_CUPM - Private API 123 // ========================================================================================== 124 125 template <device::cupm::DeviceType T> 126 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept 127 { 128 return static_cast<Vec_Seq *>(v->data); 129 } 130 131 template <device::cupm::DeviceType T> 132 inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept 133 { 134 return VECSEQCUPM(); 135 } 136 137 template <device::cupm::DeviceType T> 138 inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept 139 { 140 return VecDestroy_Seq(v); 141 } 142 143 template <device::cupm::DeviceType T> 144 inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept 145 { 146 return VecResetArray_Seq(v); 147 } 148 149 template <device::cupm::DeviceType 
T> 150 inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept 151 { 152 return VecPlaceArray_Seq(v, a); 153 } 154 155 template <device::cupm::DeviceType T> 156 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept 157 { 158 PetscMPIInt size; 159 160 PetscFunctionBegin; 161 if (alloc_missing) *alloc_missing = PETSC_FALSE; 162 PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size)); 163 PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size); 164 PetscCall(VecCreate_Seq_Private(v, host_array)); 165 PetscFunctionReturn(PETSC_SUCCESS); 166 } 167 168 // for functions with an early return based one vec size we still need to artificially bump the 169 // object state. This is to prevent the following: 170 // 171 // 0. Suppose you have a Vec { 172 // rank 0: [0], 173 // rank 1: [<empty>] 174 // } 175 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0 176 // 2. Vec enters e.g. VecSet(10) 177 // 3. rank 1 has local size 0 and bails immediately 178 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite() 179 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1 180 // 6. Vec enters VecNorm(), and calls VecNormAvailable() 181 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0 182 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function 183 // 9. 
rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early 184 template <device::cupm::DeviceType T> 185 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept 186 { 187 PetscFunctionBegin; 188 if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v))); 189 PetscFunctionReturn(PETSC_SUCCESS); 190 } 191 192 template <device::cupm::DeviceType T> 193 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept 194 { 195 PetscFunctionBegin; 196 PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array)); 197 PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx)); 198 PetscFunctionReturn(PETSC_SUCCESS); 199 } 200 201 template <device::cupm::DeviceType T> 202 template <typename BinaryFuncT> 203 inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept 204 { 205 PetscFunctionBegin; 206 if (const auto n = zout->map->n) { 207 PetscDeviceContext dctx; 208 cupmStream_t stream; 209 210 PetscCall(GetHandles_(&dctx, &stream)); 211 // clang-format off 212 PetscCallThrust( 213 const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data()); 214 215 THRUST_CALL( 216 thrust::transform, 217 stream, 218 dxptr, dxptr + n, 219 thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()), 220 thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()), 221 std::forward<BinaryFuncT>(binary) 222 ) 223 ); 224 // clang-format on 225 PetscCall(PetscLogFlops(n)); 226 PetscCall(PetscDeviceContextSynchronize(dctx)); 227 } else { 228 PetscCall(MaybeIncrementEmptyLocalVec(zout)); 229 } 230 PetscFunctionReturn(PETSC_SUCCESS); 231 } 232 233 template <device::cupm::DeviceType T> 234 template <typename UnaryFuncT> 235 inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) 
noexcept 236 { 237 const auto inplace = !yin || (xinout == yin); 238 239 PetscFunctionBegin; 240 if (const auto n = xinout->map->n) { 241 PetscDeviceContext dctx; 242 cupmStream_t stream; 243 const auto apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) { 244 PetscFunctionBegin; 245 // clang-format off 246 PetscCallThrust( 247 const auto xptr = thrust::device_pointer_cast(xinout); 248 249 THRUST_CALL( 250 thrust::transform, 251 stream, 252 xptr, xptr + n, 253 (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr, 254 std::forward<UnaryFuncT>(unary) 255 ) 256 ); 257 PetscFunctionReturn(PETSC_SUCCESS); 258 }; 259 260 PetscCall(GetHandles_(&dctx, &stream)); 261 if (inplace) { 262 PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data())); 263 } else { 264 PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data())); 265 } 266 PetscCall(PetscLogFlops(n)); 267 PetscCall(PetscDeviceContextSynchronize(dctx)); 268 } else { 269 if (inplace) { 270 PetscCall(MaybeIncrementEmptyLocalVec(xinout)); 271 } else { 272 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 273 } 274 } 275 PetscFunctionReturn(PETSC_SUCCESS); 276 } 277 278 // ========================================================================================== 279 // VecSeq_CUPM - Public API - Constructors 280 // ========================================================================================== 281 282 // VecCreateSeqCUPM() 283 template <device::cupm::DeviceType T> 284 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept 285 { 286 PetscFunctionBegin; 287 PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type)); 288 PetscFunctionReturn(PETSC_SUCCESS); 289 } 290 291 // VecCreateSeqCUPMWithArrays() 292 template <device::cupm::DeviceType T> 293 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], 
const PetscScalar device_array[], Vec *v) noexcept 294 { 295 PetscDeviceContext dctx; 296 297 PetscFunctionBegin; 298 PetscCall(GetHandles_(&dctx)); 299 // do NOT call VecSetType(), otherwise ops->create() -> create() -> 300 // createseqcupm_() is called! 301 PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE)); 302 PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array))); 303 PetscFunctionReturn(PETSC_SUCCESS); 304 } 305 306 // v->ops->duplicate 307 template <device::cupm::DeviceType T> 308 inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept 309 { 310 PetscDeviceContext dctx; 311 312 PetscFunctionBegin; 313 PetscCall(GetHandles_(&dctx)); 314 PetscCall(Duplicate_CUPMBase(v, y, dctx)); 315 PetscFunctionReturn(PETSC_SUCCESS); 316 } 317 318 // ========================================================================================== 319 // VecSeq_CUPM - Public API - Utility 320 // ========================================================================================== 321 322 // v->ops->bindtocpu 323 template <device::cupm::DeviceType T> 324 inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept 325 { 326 PetscDeviceContext dctx; 327 328 PetscFunctionBegin; 329 PetscCall(GetHandles_(&dctx)); 330 PetscCall(BindToCPU_CUPMBase(v, usehost, dctx)); 331 332 // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess 333 VecSetOp_CUPM(dot, VecDot_Seq, dot); 334 VecSetOp_CUPM(norm, VecNorm_Seq, norm); 335 VecSetOp_CUPM(tdot, VecTDot_Seq, tdot); 336 VecSetOp_CUPM(mdot, VecMDot_Seq, mdot); 337 VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>); 338 VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>); 339 v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq; 340 VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate); 341 VecSetOp_CUPM(max, VecMax_Seq, max); 342 
VecSetOp_CUPM(min, VecMin_Seq, min); 343 VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo); 344 VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo); 345 PetscFunctionReturn(PETSC_SUCCESS); 346 } 347 348 // ========================================================================================== 349 // VecSeq_CUPM - Public API - Mutators 350 // ========================================================================================== 351 352 // v->ops->getlocalvector or v->ops->getlocalvectorread 353 template <device::cupm::DeviceType T> 354 template <PetscMemoryAccessMode access> 355 inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept 356 { 357 PetscBool wisseqcupm; 358 359 PetscFunctionBegin; 360 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 361 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 362 if (wisseqcupm) { 363 if (const auto wseq = VecIMPLCast(w)) { 364 if (auto &alloced = wseq->array_allocated) { 365 const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE)); 366 367 PetscCall(PetscFree(alloced)); 368 } 369 wseq->array = nullptr; 370 wseq->unplacedarray = nullptr; 371 } 372 if (const auto wcu = VecCUPMCast(w)) { 373 if (auto &device_array = wcu->array_d) { 374 cupmStream_t stream; 375 376 PetscCall(GetHandles_(&stream)); 377 PetscCallCUPM(cupmFreeAsync(device_array, stream)); 378 } 379 PetscCall(PetscFree(w->spptr /* wcu */)); 380 } 381 } 382 if (v->petscnative && wisseqcupm) { 383 PetscCall(PetscFree(w->data)); 384 w->data = v->data; 385 w->offloadmask = v->offloadmask; 386 w->pinned_memory = v->pinned_memory; 387 w->spptr = v->spptr; 388 PetscCall(PetscObjectStateIncrease(PetscObjectCast(w))); 389 } else { 390 const auto array = &VecIMPLCast(w)->array; 391 392 if (access == PETSC_MEMORY_ACCESS_READ) { 393 PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array))); 394 } else { 395 PetscCall(VecGetArray(v, 
array)); 396 } 397 w->offloadmask = PETSC_OFFLOAD_CPU; 398 if (wisseqcupm) { 399 PetscDeviceContext dctx; 400 401 PetscCall(GetHandles_(&dctx)); 402 PetscCall(DeviceAllocateCheck_(dctx, w)); 403 } 404 } 405 PetscFunctionReturn(PETSC_SUCCESS); 406 } 407 408 // v->ops->restorelocalvector or v->ops->restorelocalvectorread 409 template <device::cupm::DeviceType T> 410 template <PetscMemoryAccessMode access> 411 inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept 412 { 413 PetscBool wisseqcupm; 414 415 PetscFunctionBegin; 416 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 417 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 418 if (v->petscnative && wisseqcupm) { 419 // the assignments to nullptr are __critical__, as w may persist after this call returns 420 // and shouldn't share data with v! 421 v->pinned_memory = w->pinned_memory; 422 v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED); 423 v->data = util::exchange(w->data, nullptr); 424 v->spptr = util::exchange(w->spptr, nullptr); 425 } else { 426 const auto array = &VecIMPLCast(w)->array; 427 428 if (access == PETSC_MEMORY_ACCESS_READ) { 429 PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array))); 430 } else { 431 PetscCall(VecRestoreArray(v, array)); 432 } 433 if (w->spptr && wisseqcupm) { 434 cupmStream_t stream; 435 436 PetscCall(GetHandles_(&stream)); 437 PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream)); 438 PetscCall(PetscFree(w->spptr)); 439 } 440 } 441 PetscFunctionReturn(PETSC_SUCCESS); 442 } 443 444 // ========================================================================================== 445 // VecSeq_CUPM - Public API - Compute Methods 446 // ========================================================================================== 447 448 // v->ops->aypx 449 template <device::cupm::DeviceType T> 450 inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) 
noexcept 451 { 452 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 453 const auto sync = n != 0; 454 PetscDeviceContext dctx; 455 456 PetscFunctionBegin; 457 PetscCall(GetHandles_(&dctx)); 458 if (alpha == PetscScalar(0.0)) { 459 cupmStream_t stream; 460 461 PetscCall(GetHandlesFrom_(dctx, &stream)); 462 PetscCall(PetscLogGpuTimeBegin()); 463 PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream)); 464 PetscCall(PetscLogGpuTimeEnd()); 465 } else if (n) { 466 const auto alphaIsOne = alpha == PetscScalar(1.0); 467 const auto calpha = cupmScalarPtrCast(&alpha); 468 cupmBlasHandle_t cupmBlasHandle; 469 470 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle)); 471 { 472 const auto yptr = DeviceArrayReadWrite(dctx, yin); 473 const auto xptr = DeviceArrayRead(dctx, xin); 474 475 PetscCall(PetscLogGpuTimeBegin()); 476 if (alphaIsOne) { 477 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 478 } else { 479 const auto one = cupmScalarCast(1.0); 480 481 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1)); 482 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 483 } 484 PetscCall(PetscLogGpuTimeEnd()); 485 } 486 PetscCall(PetscLogGpuFlops((alphaIsOne ? 
1 : 2) * n)); 487 } 488 if (sync) PetscCall(PetscDeviceContextSynchronize(dctx)); 489 PetscFunctionReturn(PETSC_SUCCESS); 490 } 491 492 // v->ops->axpy 493 template <device::cupm::DeviceType T> 494 inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept 495 { 496 PetscBool xiscupm; 497 498 PetscFunctionBegin; 499 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 500 if (xiscupm) { 501 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 502 PetscDeviceContext dctx; 503 cupmBlasHandle_t cupmBlasHandle; 504 505 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 506 PetscCall(PetscLogGpuTimeBegin()); 507 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 508 PetscCall(PetscLogGpuTimeEnd()); 509 PetscCall(PetscLogGpuFlops(2 * n)); 510 PetscCall(PetscDeviceContextSynchronize(dctx)); 511 } else { 512 PetscCall(VecAXPY_Seq(yin, alpha, xin)); 513 } 514 PetscFunctionReturn(PETSC_SUCCESS); 515 } 516 517 // v->ops->pointwisedivide 518 template <device::cupm::DeviceType T> 519 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept 520 { 521 PetscFunctionBegin; 522 if (xin->boundtocpu || yin->boundtocpu) { 523 PetscCall(VecPointwiseDivide_Seq(win, xin, yin)); 524 } else { 525 // note order of arguments! xin and yin are read, win is written! 526 PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win)); 527 } 528 PetscFunctionReturn(PETSC_SUCCESS); 529 } 530 531 // v->ops->pointwisemult 532 template <device::cupm::DeviceType T> 533 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept 534 { 535 PetscFunctionBegin; 536 if (xin->boundtocpu || yin->boundtocpu) { 537 PetscCall(VecPointwiseMult_Seq(win, xin, yin)); 538 } else { 539 // note order of arguments! xin and yin are read, win is written! 
540 PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win)); 541 } 542 PetscFunctionReturn(PETSC_SUCCESS); 543 } 544 545 namespace detail 546 { 547 548 struct reciprocal { 549 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept 550 { 551 // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex 552 // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap 553 // everything in PetscScalar... 554 return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s; 555 } 556 }; 557 558 } // namespace detail 559 560 // v->ops->reciprocal 561 template <device::cupm::DeviceType T> 562 inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept 563 { 564 PetscFunctionBegin; 565 PetscCall(pointwiseunary_(detail::reciprocal{}, xin)); 566 PetscFunctionReturn(PETSC_SUCCESS); 567 } 568 569 // v->ops->waxpy 570 template <device::cupm::DeviceType T> 571 inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept 572 { 573 PetscFunctionBegin; 574 if (alpha == PetscScalar(0.0)) { 575 PetscCall(copy(yin, win)); 576 } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) { 577 PetscDeviceContext dctx; 578 cupmBlasHandle_t cupmBlasHandle; 579 cupmStream_t stream; 580 581 PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream)); 582 { 583 const auto wptr = DeviceArrayWrite(dctx, win); 584 585 PetscCall(PetscLogGpuTimeBegin()); 586 PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true)); 587 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1)); 588 PetscCall(PetscLogGpuTimeEnd()); 589 } 590 PetscCall(PetscLogGpuFlops(2 * n)); 591 PetscCall(PetscDeviceContextSynchronize(dctx)); 592 } 593 PetscFunctionReturn(PETSC_SUCCESS); 594 } 595 596 namespace kernels 597 { 598 599 template 
<typename... Args> 600 PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr) 601 { 602 constexpr int N = sizeof...(Args); 603 const auto tx = threadIdx.x; 604 const PetscScalar *yptr_p[] = {yptr...}; 605 606 PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N]; 607 608 // load a to shared memory 609 if (tx < N) aptr_shmem[tx] = aptr[tx]; 610 __syncthreads(); 611 612 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 613 // these may look the same but give different results! 614 #if 0 615 PetscScalar sum = 0.0; 616 617 #pragma unroll 618 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 619 xptr[i] += sum; 620 #else 621 auto sum = xptr[i]; 622 623 #pragma unroll 624 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 625 xptr[i] = sum; 626 #endif 627 }); 628 return; 629 } 630 631 } // namespace kernels 632 633 namespace detail 634 { 635 636 // a helper-struct to gobble the size_t input, it is used with template parameter pack 637 // expansion such that 638 // typename repeat_type<MyType, IdxParamPack>... 639 // expands to 640 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times] 641 template <typename T, std::size_t> 642 struct repeat_type { 643 using type = T; 644 }; 645 646 } // namespace detail 647 648 template <device::cupm::DeviceType T> 649 template <std::size_t... Idx> 650 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept 651 { 652 PetscFunctionBegin; 653 // clang-format off 654 PetscCall( 655 PetscCUPMLaunchKernel1D( 656 size, 0, stream, 657 kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 658 size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()... 
659 ) 660 ); 661 // clang-format on 662 PetscFunctionReturn(PETSC_SUCCESS); 663 } 664 665 template <device::cupm::DeviceType T> 666 template <int N> 667 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept 668 { 669 PetscFunctionBegin; 670 PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{})); 671 yidx += N; 672 PetscFunctionReturn(PETSC_SUCCESS); 673 } 674 675 // v->ops->maxpy 676 template <device::cupm::DeviceType T> 677 inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept 678 { 679 const auto n = xin->map->n; 680 PetscDeviceContext dctx; 681 cupmStream_t stream; 682 683 PetscFunctionBegin; 684 PetscCall(GetHandles_(&dctx, &stream)); 685 { 686 const auto xptr = DeviceArrayReadWrite(dctx, xin); 687 PetscScalar *d_alpha = nullptr; 688 PetscInt yidx = 0; 689 690 // placement of early-return is deliberate, we would like to capture the 691 // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail 692 if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS); 693 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha)); 694 PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream)); 695 PetscCall(PetscLogGpuTimeBegin()); 696 do { 697 switch (nv - yidx) { 698 case 7: 699 PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 700 break; 701 case 6: 702 PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 703 break; 704 case 5: 705 PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 706 break; 707 case 4: 708 PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 709 break; 710 case 3: 711 
PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 712 break; 713 case 2: 714 PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 715 break; 716 case 1: 717 PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 718 break; 719 default: // 8 or more 720 PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 721 break; 722 } 723 } while (yidx < nv); 724 PetscCall(PetscLogGpuTimeEnd()); 725 PetscCall(PetscDeviceFree(dctx, d_alpha)); 726 } 727 PetscCall(PetscLogGpuFlops(nv * 2 * n)); 728 PetscCall(PetscDeviceContextSynchronize(dctx)); 729 PetscFunctionReturn(PETSC_SUCCESS); 730 } 731 732 template <device::cupm::DeviceType T> 733 inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept 734 { 735 PetscFunctionBegin; 736 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 737 PetscDeviceContext dctx; 738 cupmBlasHandle_t cupmBlasHandle; 739 740 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 741 // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the 742 // second 743 PetscCall(PetscLogGpuTimeBegin()); 744 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z))); 745 PetscCall(PetscLogGpuTimeEnd()); 746 PetscCall(PetscLogGpuFlops(2 * n - 1)); 747 } else { 748 *z = 0.0; 749 } 750 PetscFunctionReturn(PETSC_SUCCESS); 751 } 752 753 #define MDOT_WORKGROUP_NUM 128 754 #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM 755 756 namespace kernels 757 { 758 759 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept 760 { 761 const auto group_entries = (size - 1) / gridDim.x + 1; 762 // for very small vectors, a group should still do some work 763 return group_entries ? group_entries : 1; 764 } 765 766 template <typename... 
ConstPetscScalarPointer> 767 PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y) 768 { 769 constexpr int N = sizeof...(ConstPetscScalarPointer); 770 const PetscScalar *ylocal[] = {y...}; 771 PetscScalar sumlocal[N]; 772 773 PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE]; 774 775 // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate 776 // types, so each of these go on separate lines... 777 const auto tx = threadIdx.x; 778 const auto bx = blockIdx.x; 779 const auto bdx = blockDim.x; 780 const auto gdx = gridDim.x; 781 const auto worksize = EntriesPerGroup(size); 782 const auto begin = tx + bx * worksize; 783 const auto end = min((bx + 1) * worksize, size); 784 785 #pragma unroll 786 for (auto i = 0; i < N; ++i) sumlocal[i] = 0; 787 788 for (auto i = begin; i < end; i += bdx) { 789 const auto xi = x[i]; // load only once from global memory! 790 791 #pragma unroll 792 for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi; 793 } 794 795 #pragma unroll 796 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i]; 797 798 // parallel reduction 799 for (auto stride = bdx / 2; stride > 0; stride /= 2) { 800 __syncthreads(); 801 if (tx < stride) { 802 #pragma unroll 803 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE]; 804 } 805 } 806 // bottom N threads per block write to global memory 807 // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread 808 // writes to the same sections in the above loop that it is about to read from below, but 809 // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard. 
810 __syncthreads(); 811 if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE]; 812 return; 813 } 814 815 namespace 816 { 817 818 PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results) 819 { 820 int local_i = 0; 821 PetscScalar local_results[8]; 822 823 // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer 824 // 825 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 826 // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ... 827 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 828 // | ______________________________________________________/ 829 // | / <- MDOT_WORKGROUP_NUM -> 830 // |/ 831 // + 832 // v 833 // *-*-* 834 // | | | ... 835 // *-*-* 836 // 837 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 838 PetscScalar z_sum = 0; 839 840 for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j]; 841 local_results[local_i++] = z_sum; 842 }); 843 // if we needed more than 1 workgroup to handle the vector we should sync since other threads 844 // may currently be reading from results 845 if (size >= MDOT_WORKGROUP_SIZE) __syncthreads(); 846 // Local buffer is now written to global memory 847 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 848 const auto j = --local_i; 849 850 if (j >= 0) results[i] = local_results[j]; 851 }); 852 return; 853 } 854 855 } // namespace 856 857 } // namespace kernels 858 859 template <device::cupm::DeviceType T> 860 template <std::size_t... 
Idx>
// Launches one mdot_kernel instantiation covering sizeof...(Idx) y-vectors; the
// index_sequence expands both the kernel's pointer-pack template arguments and the
// DeviceArrayRead accessors for each vector.
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Convenience wrapper: dispatches a batch of N y-vectors starting at yin[yidx], writing
// the per-block partials at the matching offset into results, then advances yidx by N.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Real-scalar multi-dot implementation (std::false_type overload, selected when
// !PetscDefined(USE_COMPLEX) in mdot() below): batches the y-vectors in groups of up to 8,
// launching the custom mdot_kernel per batch, optionally round-robining the batches over
// forked sub-contexts (streams) for concurrency, then reduces the per-block partials with
// sum_kernel and copies the final values back to the host array z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto   num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto   n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // round-robin the batches over the forked sub-contexts (if any)
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // lone trailing vector: a single dot is cheaper via cuBLAS/hipBLAS (host pointer
        // mode, result written straight into z[yidx]) than via the batched kernel
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    // join also synchronizes the sub-contexts back into dctx before sum_kernel runs below
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // collapse the per-block partials into one value per batched vector, in place
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  // NOTE(review): this logs host flops (PetscLogFlops) for device work while mdot() also logs
  // PetscLogGpuFlops -- possibly intentional (the "double count" REVIEW note in mdot()), confirm
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE
// Complex-scalar multi-dot implementation (std::true_type overload, selected when
// PetscDefined(USE_COMPLEX) in mdot() below): issues one conjugated cuBLAS/hipBLAS dot per
// y-vector, round-robined over up to 8 forked sub-contexts, with results accumulated in a
// device buffer (device pointer mode) and copied back to z at the end.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // temporarily switch the handle to device pointer mode so the result lands directly in
    // d_z[i] without a per-dot host synchronization; restore the caller's mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  // join synchronizes the sub-contexts into dctx before the result copy below
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot
// z[i] = conj(y[i]) . x for each of the nv vectors in yin. Dispatches at compile time to the
// real (custom kernel) or complex (cuBLAS-based) mdot_ overload via integral_constant.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // zero-length local vector: all dot products are trivially zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->set
// x[i] = alpha for all i; memset for alpha == 0, thrust::fill otherwise.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto xptr = DeviceArrayWrite(dctx, xin);

    if (alpha == PetscScalar(0.0)) {
      PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
    } else {
      const auto dptr = thrust::device_pointer_cast(xptr.data());

      PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
    }
    if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->scale
// x = alpha * x via cuBLAS/hipBLAS scal; no-ops for alpha == 1 and defers to set() for
// alpha == 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(set(xin, alpha));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: still bump object state so downstream caches are invalidated
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->tdot
// z = x . y without conjugation ("transpose" dot), via cuBLAS/hipBLAS dotu.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->copy
// y = x, picking the memcpy direction from the two vectors' offload masks so data is moved
// between whichever memory spaces currently hold the authoritative copies.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      // an unallocated CUPM destination can receive data wherever the source has it; a
      // non-CUPM destination is treated like a host (CPU) destination below
      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // destination lives on the host, so go through the generic host array accessors
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->swap
// Exchanges the device contents of x and y via cuBLAS/hipBLAS swap.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpby
// y = alpha*x + beta*y, special-casing alpha == 0 (pure scale), beta == 1 (axpy) and
// alpha == 1 (aypx) before falling back to an explicit scal + axpy sequence.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    PetscCall(axpy(yin, alpha, xin));
  } else if (alpha ==
PetscScalar(1.0)) {
    PetscCall(aypx(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        // y = alpha*x realized as copy-then-scale
        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        // general case: y = beta*y, then y += alpha*x
        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpbypcz
// z = alpha*x + beta*y + gamma*z, composed from scale() and two axpy() calls.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
  PetscCall(axpy(zin, alpha, xin));
  PetscCall(axpy(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->norm
// Computes the requested norm(s) of x via cuBLAS/hipBLAS (asum, nrm2, amax). For
// NORM_1_AND_2 the case deliberately falls through from NORM_1 into NORM_2 after bumping z
// so both results are written.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      // amax only yields the location of the largest |x_i|; fetch that entry to the host to
      // form the actual value
      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // max_loc is 1-based (BLAS amax convention), hence the -1 when indexing
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: all norms are zero; the second store covers z[1] for NORM_1_AND_2
    z[0]                     = 0.0;
    z[type == NORM_1_AND_2]  = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Elementwise functor for dotnorm2: maps (s_i, t_i) to (s_i * conj(t_i), t_i * conj(t_i)),
// i.e. the per-entry contributions to conj(t).s and |t|^2.
struct dotnorm2_mult {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};

// it is positively __bananas__ that thrust does not define default operator+ for tuples... I
// would do it myself but now I am worried that they do so on purpose...
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};

} // namespace detail

// v->ops->dotnorm2
// Computes dp = conj(t).s and nm = conj(t).t in a single fused pass over both vectors using
// thrust::inner_product with the tuple functors above.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Elementwise complex conjugation functor for pointwiseunary_().
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail

// v->ops->conjugate
// x_i = conj(x_i); a no-op on real builds.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Projects scalars (or (scalar, index) tuples) onto their real parts; used as the transform
// step of the min/max reductions so complex builds compare real parts only.
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};

// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  // Reduction step for (value, index) pairs: picks the winner per the wrapped comparison,
  // breaking exact ties in favor of the lower index.
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ?
y : x;
    }
    // otherwise return x
    return x;
  }

private:
  // expose the (possibly empty) base operator
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};

} // namespace detail

// Common core for max() and min(): reduces the vector with the given tuple comparison (when
// the caller wants the index p as well, over a zip of values and counting indices) or plain
// unary comparison (values only). The caller pre-seeds *m (and *p) with the identity for its
// reduction.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
#if PetscDefined(USE_COMPLEX)
#define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
#define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair every value with its index so the reduction can report the winner's location
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
#undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->max
// *m = max_i Re(v_i) (and *p its index, if requested); identity is PETSC_MIN_REAL.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
{
  using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
  using unary_functor = thrust::maximum<PetscReal>;

  PetscFunctionBegin;
  *m = PETSC_MIN_REAL;
  // use {} constructor syntax otherwise most vexing parse
  PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->min
// *m = min_i Re(v_i) (and *p its index, if requested); identity is PETSC_MAX_REAL.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
{
  using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
  using unary_functor = thrust::minimum<PetscReal>;

  PetscFunctionBegin;
  *m = PETSC_MAX_REAL;
  // use {} constructor syntax otherwise most vexing parse
  PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->sum
// *sum = sum_i v_i via thrust::reduce on the device array.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
    // REVIEW ME: why not cupmBlasXasum()?
    PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
    // REVIEW ME: must be at least n additions
    PetscCall(PetscLogGpuFlops(n));
  } else {
    *sum = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v_i += shift for all i, via the generic pointwise-unary driver.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(device::cupm::functors::make_plus_equals(shift), v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Fills v with random values: directly into device memory when the generator is CURAND,
// otherwise into the host array (the device copy is refreshed lazily by the accessors).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->setpreallocation
// COO preallocation: delegates to the host Seq implementation, then mirrors the COO
// metadata onto the device via the CUPM base helper.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

// Device helper for COO assembly: for each target entry i, gathers the user values vv[]
// through the permutation perm over the range jmap[i]..jmap[i+1), and either overwrites
// (INSERT_VALUES) or accumulates into xv at the location produced by xvindex(i).
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}

namespace
{

// Sequential-vector COO kernel: the target index is simply i (identity mapping).
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
#endif

} // namespace kernels

// v->ops->setvaluescoo
// COO assembly: scatters/accumulates the user-supplied values v[] into x per the COO
// mapping built by setpreallocationcoo(). Host-resident v[] is staged through a temporary
// device buffer first.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry, so write-only access suffices; ADD_VALUES must
    // read the existing contents as well
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ?
DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // free the staging buffer only if we allocated one above
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Implementations
// ==========================================================================================

namespace
{

// Creates a sequential CUPM vector of local size n (block size 0, eager allocation).
template <typename T>
inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscValidPointer(v, 4);
  PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Creates a sequential CUPM vector wrapping caller-provided host and/or device arrays.
template <typename T>
inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5);
  PetscValidPointer(v, 7);
  PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Common implementation for the device-array access wrappers below, parameterized on the
// memory access mode (read / write / read-write).
template <PetscMemoryAccessMode mode, typename T>
inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
  PetscValidPointer(a, 3);
  PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Counterpart of VecCUPMGetArrayAsync_Private: returns the device array to the vector.
template <PetscMemoryAccessMode mode, typename T>
inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
  PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Read-write device array access.
template <typename T>
inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <typename T>
inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Read-only device array access (const-ness handled via const_cast at this boundary).
template <typename T>
inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <typename T>
inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Write-only device array access.
template <typename T>
inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
1735 } 1736 1737 template <typename T> 1738 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1739 { 1740 PetscFunctionBegin; 1741 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a)); 1742 PetscFunctionReturn(PETSC_SUCCESS); 1743 } 1744 1745 template <typename T> 1746 inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept 1747 { 1748 PetscFunctionBegin; 1749 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1750 PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1751 PetscFunctionReturn(PETSC_SUCCESS); 1752 } 1753 1754 template <typename T> 1755 inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept 1756 { 1757 PetscFunctionBegin; 1758 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1759 PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1760 PetscFunctionReturn(PETSC_SUCCESS); 1761 } 1762 1763 template <typename T> 1764 inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept 1765 { 1766 PetscFunctionBegin; 1767 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1768 PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin)); 1769 PetscFunctionReturn(PETSC_SUCCESS); 1770 } 1771 1772 } // anonymous namespace 1773 1774 } // namespace impl 1775 1776 } // namespace cupm 1777 1778 } // namespace vec 1779 1780 } // namespace Petsc 1781 1782 #endif // __cplusplus 1783 1784 #endif // PETSCVECSEQCUPM_HPP 1785