1 #ifndef PETSCVECSEQCUPM_HPP 2 #define PETSCVECSEQCUPM_HPP 3 4 #define PETSC_SKIP_SPINLOCK // REVIEW ME: why 5 6 #include <petsc/private/veccupmimpl.h> 7 #include <petsc/private/randomimpl.h> // for _p_PetscRandom 8 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp" 9 #include "../src/sys/objects/device/impls/cupm/cupmallocator.hpp" 10 #include "../src/sys/objects/device/impls/cupm/kernels.hpp" 11 12 #if defined(__cplusplus) 13 #include <thrust/transform_reduce.h> 14 #include <thrust/reduce.h> 15 #include <thrust/functional.h> 16 #include <thrust/iterator/counting_iterator.h> 17 #include <thrust/inner_product.h> 18 19 namespace Petsc 20 { 21 22 namespace vec 23 { 24 25 namespace cupm 26 { 27 28 namespace impl 29 { 30 31 // ========================================================================================== 32 // VecSeq_CUPM 33 // ========================================================================================== 34 35 template <device::cupm::DeviceType T> 36 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 37 public: 38 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 39 40 private: 41 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 42 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 43 44 PETSC_NODISCARD static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 45 PETSC_NODISCARD static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 46 PETSC_NODISCARD static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 47 PETSC_NODISCARD static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 48 49 PETSC_NODISCARD static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 50 51 // common core for min and max 52 template <typename TupleFuncT, typename UnaryFuncT> 53 PETSC_NODISCARD static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 54 // common core for pointwise binary and pointwise unary thrust functions 55 template <typename BinaryFuncT> 56 PETSC_NODISCARD static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept; 57 template <typename UnaryFuncT> 58 PETSC_NODISCARD static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept; 59 // mdot dispatchers 60 PETSC_NODISCARD static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 61 PETSC_NODISCARD static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 62 template <std::size_t... Idx> 63 PETSC_NODISCARD static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 64 template <int> 65 PETSC_NODISCARD static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 66 template <std::size_t... Idx> 67 PETSC_NODISCARD static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 68 template <int> 69 PETSC_NODISCARD static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 70 // common core for the various create routines 71 PETSC_NODISCARD static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 72 73 public: 74 // callable directly via a bespoke function 75 PETSC_NODISCARD static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 76 PETSC_NODISCARD static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 77 78 // callable indirectly via function pointers 79 PETSC_NODISCARD static PetscErrorCode duplicate(Vec, Vec *) noexcept; 80 PETSC_NODISCARD static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept; 81 PETSC_NODISCARD static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept; 82 PETSC_NODISCARD static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept; 83 PETSC_NODISCARD static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept; 84 PETSC_NODISCARD static PetscErrorCode reciprocal(Vec) noexcept; 85 PETSC_NODISCARD static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept; 86 PETSC_NODISCARD static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 87 PETSC_NODISCARD static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept; 88 PETSC_NODISCARD static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 89 PETSC_NODISCARD static PetscErrorCode set(Vec, PetscScalar) noexcept; 90 PETSC_NODISCARD static PetscErrorCode scale(Vec, PetscScalar) noexcept; 91 PETSC_NODISCARD static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept; 92 PETSC_NODISCARD static PetscErrorCode copy(Vec, Vec) noexcept; 93 PETSC_NODISCARD static PetscErrorCode swap(Vec, Vec) noexcept; 94 PETSC_NODISCARD static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept; 95 PETSC_NODISCARD static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 96 PETSC_NODISCARD static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept; 97 PETSC_NODISCARD static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 98 PETSC_NODISCARD static PetscErrorCode destroy(Vec) noexcept; 99 PETSC_NODISCARD static PetscErrorCode conjugate(Vec) noexcept; 100 template <PetscMemoryAccessMode> 101 PETSC_NODISCARD static PetscErrorCode getlocalvector(Vec, Vec) noexcept; 102 template <PetscMemoryAccessMode> 103 PETSC_NODISCARD static PetscErrorCode restorelocalvector(Vec, Vec) noexcept; 104 PETSC_NODISCARD static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept; 105 PETSC_NODISCARD static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept; 106 PETSC_NODISCARD static PetscErrorCode sum(Vec, PetscScalar *) noexcept; 107 PETSC_NODISCARD static PetscErrorCode shift(Vec, PetscScalar) noexcept; 108 PETSC_NODISCARD static PetscErrorCode setrandom(Vec, PetscRandom) noexcept; 109 PETSC_NODISCARD static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept; 110 PETSC_NODISCARD static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept; 111 PETSC_NODISCARD static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept; 112 }; 113 114 // ========================================================================================== 115 // VecSeq_CUPM - Private API 116 // ========================================================================================== 117 118 template <device::cupm::DeviceType T> 119 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept 120 { 121 return static_cast<Vec_Seq *>(v->data); 122 } 123 124 template <device::cupm::DeviceType T> 125 inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept 126 { 127 return VECSEQCUPM(); 128 } 129 130 template <device::cupm::DeviceType T> 131 inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept 132 { 133 return VecDestroy_Seq(v); 134 } 135 136 template <device::cupm::DeviceType T> 137 inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept 138 { 139 return VecResetArray_Seq(v); 140 } 141 142 template <device::cupm::DeviceType T> 143 inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept 144 { 145 return VecPlaceArray_Seq(v, a); 146 } 147 148 template <device::cupm::DeviceType T> 149 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept 150 { 151 PetscMPIInt size; 152 153 PetscFunctionBegin; 154 if (alloc_missing) *alloc_missing = PETSC_FALSE; 155 PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size)); 156 PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size); 157 PetscCall(VecCreate_Seq_Private(v, host_array)); 158 PetscFunctionReturn(0); 159 } 160 161 // for functions with an early return based one vec size we still need to artificially bump the 162 // object state. This is to prevent the following: 163 // 164 // 0. Suppose you have a Vec { 165 // rank 0: [0], 166 // rank 1: [<empty>] 167 // } 168 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0 169 // 2. Vec enters e.g. VecSet(10) 170 // 3. rank 1 has local size 0 and bails immediately 171 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite() 172 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1 173 // 6. Vec enters VecNorm(), and calls VecNormAvailable() 174 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0 175 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function 176 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early 177 template <device::cupm::DeviceType T> 178 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept 179 { 180 PetscFunctionBegin; 181 if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v))); 182 PetscFunctionReturn(0); 183 } 184 185 template <device::cupm::DeviceType T> 186 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept 187 { 188 PetscFunctionBegin; 189 PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array)); 190 PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx)); 191 PetscFunctionReturn(0); 192 } 193 194 template <device::cupm::DeviceType T> 195 template <typename BinaryFuncT> 196 inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept 197 { 198 PetscFunctionBegin; 199 if (const auto n = zout->map->n) { 200 PetscDeviceContext dctx; 201 202 PetscCall(GetHandles_(&dctx)); 203 PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<BinaryFuncT>(binary), n, DeviceArrayRead(dctx, xin).data(), DeviceArrayRead(dctx, yin).data(), DeviceArrayWrite(dctx, zout).data())); 204 PetscCall(PetscDeviceContextSynchronize(dctx)); 205 } else { 206 PetscCall(MaybeIncrementEmptyLocalVec(zout)); 207 } 208 PetscFunctionReturn(0); 209 } 210 211 template <device::cupm::DeviceType T> 212 template <typename UnaryFuncT> 213 inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept 214 { 215 const auto inplace = !yin || (xinout == yin); 216 217 PetscFunctionBegin; 218 if (const auto n = xinout->map->n) { 219 PetscDeviceContext dctx; 220 221 PetscCall(GetHandles_(&dctx)); 222 if (inplace) { 223 PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayReadWrite(dctx, xinout).data())); 224 } else { 225 PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data())); 226 } 227 PetscCall(PetscDeviceContextSynchronize(dctx)); 228 } else { 229 if (inplace) { 230 PetscCall(MaybeIncrementEmptyLocalVec(xinout)); 231 } else { 232 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 233 } 234 } 235 PetscFunctionReturn(0); 236 } 237 238 // ========================================================================================== 239 // VecSeq_CUPM - Public API - Constructors 240 // ========================================================================================== 241 242 // VecCreateSeqCUPM() 243 template <device::cupm::DeviceType T> 244 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept 245 { 246 PetscFunctionBegin; 247 PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type)); 248 PetscFunctionReturn(0); 249 } 250 251 // VecCreateSeqCUPMWithArrays() 252 template <device::cupm::DeviceType T> 253 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept 254 { 255 PetscDeviceContext dctx; 256 257 PetscFunctionBegin; 258 PetscCall(GetHandles_(&dctx)); 259 // do NOT call VecSetType(), otherwise ops->create() -> create() -> 260 // createseqcupm_() is called! 261 PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE)); 262 PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array))); 263 PetscFunctionReturn(0); 264 } 265 266 // v->ops->duplicate 267 template <device::cupm::DeviceType T> 268 inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept 269 { 270 PetscDeviceContext dctx; 271 272 PetscFunctionBegin; 273 PetscCall(GetHandles_(&dctx)); 274 PetscCall(Duplicate_CUPMBase(v, y, dctx)); 275 PetscFunctionReturn(0); 276 } 277 278 // ========================================================================================== 279 // VecSeq_CUPM - Public API - Utility 280 // ========================================================================================== 281 282 // v->ops->bindtocpu 283 template <device::cupm::DeviceType T> 284 inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept 285 { 286 PetscDeviceContext dctx; 287 288 PetscFunctionBegin; 289 PetscCall(GetHandles_(&dctx)); 290 PetscCall(BindToCPU_CUPMBase(v, usehost, dctx)); 291 292 // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess 293 VecSetOp_CUPM(dot, VecDot_Seq, dot); 294 VecSetOp_CUPM(norm, VecNorm_Seq, norm); 295 VecSetOp_CUPM(tdot, VecTDot_Seq, tdot); 296 VecSetOp_CUPM(mdot, VecMDot_Seq, mdot); 297 VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>); 298 VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>); 299 v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq; 300 VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate); 301 VecSetOp_CUPM(max, VecMax_Seq, max); 302 VecSetOp_CUPM(min, VecMin_Seq, min); 303 VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo); 304 VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo); 305 PetscFunctionReturn(0); 306 } 307 308 // ========================================================================================== 309 // VecSeq_CUPM - Public API - Mutatators 310 // ========================================================================================== 311 312 // v->ops->getlocalvector or v->ops->getlocalvectorread 313 template <device::cupm::DeviceType T> 314 template <PetscMemoryAccessMode access> 315 inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept 316 { 317 PetscBool wisseqcupm; 318 319 PetscFunctionBegin; 320 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 321 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 322 if (wisseqcupm) { 323 if (const auto wseq = VecIMPLCast(w)) { 324 if (auto &alloced = wseq->array_allocated) { 325 const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE)); 326 327 PetscCall(PetscFree(alloced)); 328 } 329 wseq->array = nullptr; 330 wseq->unplacedarray = nullptr; 331 } 332 if (const auto wcu = VecCUPMCast(w)) { 333 if (auto &device_array = wcu->array_d) { 334 cupmStream_t stream; 335 336 PetscCall(GetHandles_(&stream)); 337 PetscCallCUPM(cupmFreeAsync(device_array, stream)); 338 } 339 PetscCall(PetscFree(w->spptr /* wcu */)); 340 } 341 } 342 if (v->petscnative && wisseqcupm) { 343 PetscCall(PetscFree(w->data)); 344 w->data = v->data; 345 w->offloadmask = v->offloadmask; 346 w->pinned_memory = v->pinned_memory; 347 w->spptr = v->spptr; 348 PetscCall(PetscObjectStateIncrease(PetscObjectCast(w))); 349 } else { 350 const auto array = &VecIMPLCast(w)->array; 351 352 if (access == PETSC_MEMORY_ACCESS_READ) { 353 PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array))); 354 } else { 355 PetscCall(VecGetArray(v, array)); 356 } 357 w->offloadmask = PETSC_OFFLOAD_CPU; 358 if (wisseqcupm) { 359 PetscDeviceContext dctx; 360 361 PetscCall(GetHandles_(&dctx)); 362 PetscCall(DeviceAllocateCheck_(dctx, w)); 363 } 364 } 365 PetscFunctionReturn(0); 366 } 367 368 // v->ops->restorelocalvector or v->ops->restorelocalvectorread 369 template <device::cupm::DeviceType T> 370 template <PetscMemoryAccessMode access> 371 inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept 372 { 373 PetscBool wisseqcupm; 374 375 PetscFunctionBegin; 376 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 377 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 378 if (v->petscnative && wisseqcupm) { 379 // the assignments to nullptr are __critical__, as w may persist after this call returns 380 // and shouldn't share data with v! 381 v->pinned_memory = w->pinned_memory; 382 v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED); 383 v->data = util::exchange(w->data, nullptr); 384 v->spptr = util::exchange(w->spptr, nullptr); 385 } else { 386 const auto array = &VecIMPLCast(w)->array; 387 388 if (access == PETSC_MEMORY_ACCESS_READ) { 389 PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array))); 390 } else { 391 PetscCall(VecRestoreArray(v, array)); 392 } 393 if (w->spptr && wisseqcupm) { 394 cupmStream_t stream; 395 396 PetscCall(GetHandles_(&stream)); 397 PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream)); 398 PetscCall(PetscFree(w->spptr)); 399 } 400 } 401 PetscFunctionReturn(0); 402 } 403 404 // ========================================================================================== 405 // VecSeq_CUPM - Public API - Compute Methods 406 // ========================================================================================== 407 408 // v->ops->aypx 409 template <device::cupm::DeviceType T> 410 inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept 411 { 412 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 413 const auto sync = n != 0; 414 PetscDeviceContext dctx; 415 416 PetscFunctionBegin; 417 PetscCall(GetHandles_(&dctx)); 418 if (alpha == PetscScalar(0.0)) { 419 cupmStream_t stream; 420 421 PetscCall(GetHandlesFrom_(dctx, &stream)); 422 PetscCall(PetscLogGpuTimeBegin()); 423 PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream)); 424 PetscCall(PetscLogGpuTimeEnd()); 425 } else if (n) { 426 const auto alphaIsOne = alpha == PetscScalar(1.0); 427 const auto calpha = cupmScalarPtrCast(&alpha); 428 cupmBlasHandle_t cupmBlasHandle; 429 430 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle)); 431 { 432 const auto yptr = DeviceArrayReadWrite(dctx, yin); 433 const auto xptr = DeviceArrayRead(dctx, xin); 434 435 PetscCall(PetscLogGpuTimeBegin()); 436 if (alphaIsOne) { 437 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 438 } else { 439 const auto one = cupmScalarCast(1.0); 440 441 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1)); 442 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 443 } 444 PetscCall(PetscLogGpuTimeEnd()); 445 } 446 PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n)); 447 } 448 if (sync) PetscCall(PetscDeviceContextSynchronize(dctx)); 449 PetscFunctionReturn(0); 450 } 451 452 // v->ops->axpy 453 template <device::cupm::DeviceType T> 454 inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept 455 { 456 PetscBool xiscupm; 457 458 PetscFunctionBegin; 459 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 460 if (xiscupm) { 461 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 462 PetscDeviceContext dctx; 463 cupmBlasHandle_t cupmBlasHandle; 464 465 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 466 PetscCall(PetscLogGpuTimeBegin()); 467 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 468 PetscCall(PetscLogGpuTimeEnd()); 469 PetscCall(PetscLogGpuFlops(2 * n)); 470 PetscCall(PetscDeviceContextSynchronize(dctx)); 471 } else { 472 PetscCall(VecAXPY_Seq(yin, alpha, xin)); 473 } 474 PetscFunctionReturn(0); 475 } 476 477 // v->ops->pointwisedivide 478 template <device::cupm::DeviceType T> 479 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept 480 { 481 PetscFunctionBegin; 482 if (xin->boundtocpu || yin->boundtocpu) { 483 PetscCall(VecPointwiseDivide_Seq(win, xin, yin)); 484 } else { 485 // note order of arguments! xin and yin are read, win is written! 486 PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win)); 487 } 488 PetscFunctionReturn(0); 489 } 490 491 // v->ops->pointwisemult 492 template <device::cupm::DeviceType T> 493 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept 494 { 495 PetscFunctionBegin; 496 if (xin->boundtocpu || yin->boundtocpu) { 497 PetscCall(VecPointwiseMult_Seq(win, xin, yin)); 498 } else { 499 // note order of arguments! xin and yin are read, win is written! 500 PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win)); 501 } 502 PetscFunctionReturn(0); 503 } 504 505 namespace detail 506 { 507 508 struct reciprocal { 509 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept 510 { 511 // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex 512 // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap 513 // everything in PetscScalar... 514 return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s; 515 } 516 }; 517 518 } // namespace detail 519 520 // v->ops->reciprocal 521 template <device::cupm::DeviceType T> 522 inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept 523 { 524 PetscFunctionBegin; 525 PetscCall(pointwiseunary_(detail::reciprocal{}, xin)); 526 PetscFunctionReturn(0); 527 } 528 529 // v->ops->waxpy 530 template <device::cupm::DeviceType T> 531 inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept 532 { 533 PetscFunctionBegin; 534 if (alpha == PetscScalar(0.0)) { 535 PetscCall(copy(yin, win)); 536 } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) { 537 PetscDeviceContext dctx; 538 cupmBlasHandle_t cupmBlasHandle; 539 cupmStream_t stream; 540 541 PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream)); 542 { 543 const auto wptr = DeviceArrayWrite(dctx, win); 544 545 PetscCall(PetscLogGpuTimeBegin()); 546 PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true)); 547 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1)); 548 PetscCall(PetscLogGpuTimeEnd()); 549 } 550 PetscCall(PetscLogGpuFlops(2 * n)); 551 PetscCall(PetscDeviceContextSynchronize(dctx)); 552 } 553 PetscFunctionReturn(0); 554 } 555 556 namespace kernels 557 { 558 559 template <typename... Args> 560 PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr) 561 { 562 constexpr int N = sizeof...(Args); 563 const auto tx = threadIdx.x; 564 const PetscScalar *yptr_p[] = {yptr...}; 565 566 PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N]; 567 568 // load a to shared memory 569 if (tx < N) aptr_shmem[tx] = aptr[tx]; 570 __syncthreads(); 571 572 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 573 // these may look the same but give different results! 574 #if 0 575 PetscScalar sum = 0.0; 576 577 #pragma unroll 578 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 579 xptr[i] += sum; 580 #else 581 auto sum = xptr[i]; 582 583 #pragma unroll 584 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 585 xptr[i] = sum; 586 #endif 587 }); 588 return; 589 } 590 591 } // namespace kernels 592 593 namespace detail 594 { 595 596 // a helper-struct to gobble the size_t input, it is used with template parameter pack 597 // expansion such that 598 // typename repeat_type<MyType, IdxParamPack>... 599 // expands to 600 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times] 601 template <typename T, std::size_t> 602 struct repeat_type { 603 using type = T; 604 }; 605 606 } // namespace detail 607 608 template <device::cupm::DeviceType T> 609 template <std::size_t... Idx> 610 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept 611 { 612 PetscFunctionBegin; 613 // clang-format off 614 PetscCall( 615 PetscCUPMLaunchKernel1D( 616 size, 0, stream, 617 kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 618 size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()... 619 ) 620 ); 621 // clang-format on 622 PetscFunctionReturn(0); 623 } 624 625 template <device::cupm::DeviceType T> 626 template <int N> 627 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept 628 { 629 PetscFunctionBegin; 630 PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{})); 631 yidx += N; 632 PetscFunctionReturn(0); 633 } 634 635 // v->ops->maxpy 636 template <device::cupm::DeviceType T> 637 inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept 638 { 639 const auto n = xin->map->n; 640 PetscDeviceContext dctx; 641 cupmStream_t stream; 642 643 PetscFunctionBegin; 644 PetscCall(GetHandles_(&dctx, &stream)); 645 { 646 const auto xptr = DeviceArrayReadWrite(dctx, xin); 647 PetscScalar *d_alpha = nullptr; 648 PetscInt yidx = 0; 649 650 // placement of early-return is deliberate, we would like to capture the 651 // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail 652 if (!n || !nv) PetscFunctionReturn(0); 653 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha)); 654 PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream)); 655 PetscCall(PetscLogGpuTimeBegin()); 656 do { 657 switch (nv - yidx) { 658 case 7: 659 PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 660 break; 661 case 6: 662 PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 663 break; 664 case 5: 665 PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 666 break; 667 case 4: 668 PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 669 break; 670 case 3: 671 PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 672 break; 673 case 2: 674 PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 675 break; 676 case 1: 677 PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 678 break; 679 default: // 8 or more 680 PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 681 break; 682 } 683 } while (yidx < nv); 684 PetscCall(PetscLogGpuTimeEnd()); 685 PetscCall(PetscDeviceFree(dctx, d_alpha)); 686 } 687 PetscCall(PetscLogGpuFlops(nv * 2 * n)); 688 PetscCall(PetscDeviceContextSynchronize(dctx)); 689 PetscFunctionReturn(0); 690 } 691 692 template <device::cupm::DeviceType T> 693 inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept 694 { 695 PetscFunctionBegin; 696 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 697 PetscDeviceContext dctx; 698 cupmBlasHandle_t cupmBlasHandle; 699 700 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 701 // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the 702 // second 703 PetscCall(PetscLogGpuTimeBegin()); 704 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z))); 705 PetscCall(PetscLogGpuTimeEnd()); 706 PetscCall(PetscLogGpuFlops(2 * n - 1)); 707 } else { 708 *z = 0.0; 709 } 710 PetscFunctionReturn(0); 711 } 712 713 #define MDOT_WORKGROUP_NUM 128 714 #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM 715 716 namespace kernels 717 { 718 719 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept 720 { 721 const auto group_entries = (size - 1) / gridDim.x + 1; 722 // for very small vectors, a group should still do some work 723 return group_entries ? group_entries : 1; 724 } 725 726 template <typename... ConstPetscScalarPointer> 727 PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y) 728 { 729 constexpr int N = sizeof...(ConstPetscScalarPointer); 730 const PetscScalar *ylocal[] = {y...}; 731 PetscScalar sumlocal[N]; 732 733 PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE]; 734 735 // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate 736 // types, so each of these go on separate lines... 737 const auto tx = threadIdx.x; 738 const auto bx = blockIdx.x; 739 const auto bdx = blockDim.x; 740 const auto gdx = gridDim.x; 741 const auto worksize = EntriesPerGroup(size); 742 const auto begin = tx + bx * worksize; 743 const auto end = min((bx + 1) * worksize, size); 744 745 #pragma unroll 746 for (auto i = 0; i < N; ++i) sumlocal[i] = 0; 747 748 for (auto i = begin; i < end; i += bdx) { 749 const auto xi = x[i]; // load only once from global memory! 750 751 #pragma unroll 752 for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi; 753 } 754 755 #pragma unroll 756 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i]; 757 758 // parallel reduction 759 for (auto stride = bdx / 2; stride > 0; stride /= 2) { 760 __syncthreads(); 761 if (tx < stride) { 762 #pragma unroll 763 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE]; 764 } 765 } 766 // bottom N threads per block write to global memory 767 // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread 768 // writes to the same sections in the above loop that it is about to read from below, but 769 // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard. 770 __syncthreads(); 771 if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE]; 772 return; 773 } 774 775 PETSC_KERNEL_DECL static void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results) 776 { 777 int local_i = 0; 778 PetscScalar local_results[8]; 779 780 // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer 781 // 782 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 783 // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ... 784 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 785 // | ______________________________________________________/ 786 // | / <- MDOT_WORKGROUP_NUM -> 787 // |/ 788 // + 789 // v 790 // *-*-* 791 // | | | ... 792 // *-*-* 793 // 794 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 795 PetscScalar z_sum = 0; 796 797 for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j]; 798 local_results[local_i++] = z_sum; 799 }); 800 // if we needed more than 1 workgroup to handle the vector we should sync since other threads 801 // may currently be reading from results 802 if (size >= MDOT_WORKGROUP_SIZE) __syncthreads(); 803 // Local buffer is now written to global memory 804 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 805 const auto j = --local_i; 806 807 if (j >= 0) results[i] = local_results[j]; 808 }); 809 return; 810 } 811 812 } // namespace kernels 813 814 template <device::cupm::DeviceType T> 815 template <std::size_t... Idx> 816 inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept 817 { 818 PetscFunctionBegin; 819 // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches 820 // 128 blocks of 128 threads every time which may be wasteful 821 // clang-format off 822 PetscCallCUPM( 823 cupmLaunchKernel( 824 kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 825 MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream, 826 xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()... 827 ) 828 ); 829 // clang-format on 830 PetscFunctionReturn(0); 831 } 832 833 template <device::cupm::DeviceType T> 834 template <int N> 835 inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept 836 { 837 PetscFunctionBegin; 838 PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{})); 839 yidx += N; 840 PetscFunctionReturn(0); 841 } 842 843 template <device::cupm::DeviceType T> 844 inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept 845 { 846 // the largest possible size of a batch 847 constexpr PetscInt batchsize = 8; 848 // how many sub streams to create, if nv <= batchsize we can do this without looping, so we 849 // do not create substreams. Note we don't create more than 8 streams, in practice we could 850 // not get more parallelism with higher numbers. 851 const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0; 852 const auto n = xin->map->n; 853 // number of vectors that we handle via the batches. note any singletons are handled by 854 // cublas, hence the nv-1. 855 const auto nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv; 856 const auto nwork = nvbatch * MDOT_WORKGROUP_NUM; 857 PetscScalar *d_results; 858 cupmStream_t stream; 859 860 PetscFunctionBegin; 861 PetscCall(GetHandlesFrom_(dctx, &stream)); 862 // allocate scratchpad memory for the results of individual work groups 863 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results)); 864 { 865 const auto xptr = DeviceArrayRead(dctx, xin); 866 PetscInt yidx = 0; 867 auto subidx = 0; 868 auto cur_stream = stream; 869 auto cur_ctx = dctx; 870 PetscDeviceContext *sub = nullptr; 871 PetscStreamType stype; 872 873 // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of 874 // sub. Ideally the parent context should also join in on the fork, but it is extremely 875 // fiddly to do so presently 876 PetscCall(PetscDeviceContextGetStreamType(dctx, &stype)); 877 if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING; 878 // If we have a globally blocking stream create nonblocking streams instead (as we can 879 // locally exploit the parallelism). Otherwise use the prescribed stream type. 880 PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub)); 881 PetscCall(PetscLogGpuTimeBegin()); 882 do { 883 if (num_sub_streams) { 884 cur_ctx = sub[subidx++ % num_sub_streams]; 885 PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream)); 886 } 887 // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9; 888 // it is very likely better to do 4+5 rather than 8+1 889 switch (nv - yidx) { 890 case 7: 891 PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 892 break; 893 case 6: 894 PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 895 break; 896 case 5: 897 PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 898 break; 899 case 4: 900 PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 901 break; 902 case 3: 903 PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 904 break; 905 case 2: 906 PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 907 break; 908 case 1: { 909 cupmBlasHandle_t cupmBlasHandle; 910 911 PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle)); 912 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx))); 913 ++yidx; 914 } break; 915 default: // 8 or more 916 PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 917 break; 918 } 919 } while (yidx < nv); 920 PetscCall(PetscLogGpuTimeEnd()); 921 PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub)); 922 } 923 924 PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results)); 925 // copy result of device reduction to host 926 PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream)); 927 // do these now while final reduction is in flight 928 PetscCall(PetscLogFlops(nwork)); 929 PetscCall(PetscDeviceFree(dctx, d_results)); 930 PetscFunctionReturn(0); 931 } 932 933 #undef MDOT_WORKGROUP_NUM 934 #undef MDOT_WORKGROUP_SIZE 935 936 template <device::cupm::DeviceType T> 937 inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept 938 { 939 // probably not worth it to run more than 8 of these at a time? 940 const auto n_sub = PetscMin(nv, 8); 941 const auto n = static_cast<cupmBlasInt_t>(xin->map->n); 942 const auto xptr = DeviceArrayRead(dctx, xin); 943 PetscScalar *d_z; 944 PetscDeviceContext *subctx; 945 cupmStream_t stream; 946 947 PetscFunctionBegin; 948 PetscCall(GetHandlesFrom_(dctx, &stream)); 949 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z)); 950 PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx)); 951 PetscCall(PetscLogGpuTimeBegin()); 952 for (PetscInt i = 0; i < nv; ++i) { 953 const auto sub = subctx[i % n_sub]; 954 cupmBlasHandle_t handle; 955 cupmBlasPointerMode_t old_mode; 956 957 PetscCall(GetHandlesFrom_(sub, &handle)); 958 PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode)); 959 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE)); 960 PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i))); 961 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode)); 962 } 963 PetscCall(PetscLogGpuTimeEnd()); 964 PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx)); 965 PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream)); 966 PetscCall(PetscDeviceFree(dctx, d_z)); 967 // REVIEW ME: flops????? 968 PetscFunctionReturn(0); 969 } 970 971 // v->ops->mdot 972 template <device::cupm::DeviceType T> 973 inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept 974 { 975 PetscFunctionBegin; 976 if (PetscUnlikely(nv == 1)) { 977 // dot handles nv = 0 correctly 978 PetscCall(dot(xin, const_cast<Vec>(yin[0]), z)); 979 } else if (const auto n = xin->map->n) { 980 PetscDeviceContext dctx; 981 982 PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv); 983 PetscCall(GetHandles_(&dctx)); 984 PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx)); 985 // REVIEW ME: double count of flops?? 986 PetscCall(PetscLogGpuFlops(nv * (2 * n - 1))); 987 PetscCall(PetscDeviceContextSynchronize(dctx)); 988 } else { 989 PetscCall(PetscArrayzero(z, nv)); 990 } 991 PetscFunctionReturn(0); 992 } 993 994 // v->ops->set 995 template <device::cupm::DeviceType T> 996 inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept 997 { 998 const auto n = xin->map->n; 999 PetscDeviceContext dctx; 1000 cupmStream_t stream; 1001 1002 PetscFunctionBegin; 1003 PetscCall(GetHandles_(&dctx, &stream)); 1004 { 1005 const auto xptr = DeviceArrayWrite(dctx, xin); 1006 1007 if (alpha == PetscScalar(0.0)) { 1008 PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream)); 1009 } else { 1010 PetscCall(device::cupm::impl::ThrustSet<T>(stream, n, xptr.data(), &alpha)); 1011 } 1012 if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing 1013 } 1014 PetscFunctionReturn(0); 1015 } 1016 1017 // v->ops->scale 1018 template <device::cupm::DeviceType T> 1019 inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept 1020 { 1021 PetscFunctionBegin; 1022 if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(0); 1023 if (PetscUnlikely(alpha == PetscScalar(0.0))) { 1024 PetscCall(set(xin, alpha)); 1025 } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1026 PetscDeviceContext dctx; 1027 cupmBlasHandle_t cupmBlasHandle; 1028 1029 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1030 PetscCall(PetscLogGpuTimeBegin()); 1031 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1)); 1032 PetscCall(PetscLogGpuTimeEnd()); 1033 PetscCall(PetscLogGpuFlops(n)); 1034 PetscCall(PetscDeviceContextSynchronize(dctx)); 1035 } else { 1036 PetscCall(MaybeIncrementEmptyLocalVec(xin)); 1037 } 1038 PetscFunctionReturn(0); 1039 } 1040 1041 // v->ops->tdot 1042 template <device::cupm::DeviceType T> 1043 inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept 1044 { 1045 PetscFunctionBegin; 1046 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1047 PetscDeviceContext dctx; 1048 cupmBlasHandle_t cupmBlasHandle; 1049 1050 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1051 PetscCall(PetscLogGpuTimeBegin()); 1052 PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z))); 1053 PetscCall(PetscLogGpuTimeEnd()); 1054 PetscCall(PetscLogGpuFlops(2 * n - 1)); 1055 } else { 1056 *z = 0.0; 1057 } 1058 PetscFunctionReturn(0); 1059 } 1060 1061 // v->ops->copy 1062 template <device::cupm::DeviceType T> 1063 inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept 1064 { 1065 PetscFunctionBegin; 1066 if (xin == yout) PetscFunctionReturn(0); 1067 if (const auto n = xin->map->n) { 1068 const auto xmask = xin->offloadmask; 1069 // silence buggy gcc warning: mode may be used uninitialized in this function 1070 auto mode = cupmMemcpyDeviceToDevice; 1071 PetscDeviceContext dctx; 1072 cupmStream_t stream; 1073 1074 // translate from PetscOffloadMask to cupmMemcpyKind 1075 switch (const auto ymask = yout->offloadmask) { 1076 case PETSC_OFFLOAD_UNALLOCATED: { 1077 PetscBool yiscupm; 1078 1079 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 1080 if (yiscupm) { 1081 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost; 1082 break; 1083 } 1084 } // fall-through if unallocated and not cupm 1085 #if PETSC_CPP_VERSION >= 17 1086 [[fallthrough]]; 1087 #endif 1088 case PETSC_OFFLOAD_CPU: 1089 mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost; 1090 break; 1091 case PETSC_OFFLOAD_BOTH: 1092 case PETSC_OFFLOAD_GPU: 1093 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 1094 break; 1095 default: 1096 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask)); 1097 } 1098 1099 PetscCall(GetHandles_(&dctx, &stream)); 1100 switch (mode) { 1101 case cupmMemcpyDeviceToDevice: // the best case 1102 case cupmMemcpyHostToDevice: { // not terrible 1103 const auto yptr = DeviceArrayWrite(dctx, yout); 1104 const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data(); 1105 1106 PetscCall(PetscLogGpuTimeBegin()); 1107 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream)); 1108 PetscCall(PetscLogGpuTimeEnd()); 1109 } break; 1110 case cupmMemcpyDeviceToHost: // not great 1111 case cupmMemcpyHostToHost: { // worst case 1112 const auto xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data(); 1113 PetscScalar *yptr; 1114 1115 PetscCall(VecGetArrayWrite(yout, &yptr)); 1116 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin()); 1117 PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true)); 1118 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd()); 1119 PetscCall(VecRestoreArrayWrite(yout, &yptr)); 1120 } break; 1121 default: 1122 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode)); 1123 } 1124 PetscCall(PetscDeviceContextSynchronize(dctx)); 1125 } else { 1126 PetscCall(MaybeIncrementEmptyLocalVec(yout)); 1127 } 1128 PetscFunctionReturn(0); 1129 } 1130 1131 // v->ops->swap 1132 template <device::cupm::DeviceType T> 1133 inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept 1134 { 1135 PetscFunctionBegin; 1136 if (xin == yin) PetscFunctionReturn(0); 1137 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1138 PetscDeviceContext dctx; 1139 cupmBlasHandle_t cupmBlasHandle; 1140 1141 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1142 PetscCall(PetscLogGpuTimeBegin()); 1143 PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 1144 PetscCall(PetscLogGpuTimeEnd()); 1145 PetscCall(PetscDeviceContextSynchronize(dctx)); 1146 } else { 1147 PetscCall(MaybeIncrementEmptyLocalVec(xin)); 1148 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 1149 } 1150 PetscFunctionReturn(0); 1151 } 1152 1153 // v->ops->axpby 1154 template <device::cupm::DeviceType T> 1155 inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept 1156 { 1157 PetscFunctionBegin; 1158 if (alpha == PetscScalar(0.0)) { 1159 PetscCall(scale(yin, beta)); 1160 } else if (beta == PetscScalar(1.0)) { 1161 PetscCall(axpy(yin, alpha, xin)); 1162 } else if (alpha == PetscScalar(1.0)) { 1163 PetscCall(aypx(yin, beta, xin)); 1164 } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) { 1165 const auto betaIsZero = beta == PetscScalar(0.0); 1166 const auto aptr = cupmScalarPtrCast(&alpha); 1167 PetscDeviceContext dctx; 1168 cupmBlasHandle_t cupmBlasHandle; 1169 1170 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1171 { 1172 const auto xptr = DeviceArrayRead(dctx, xin); 1173 1174 if (betaIsZero /* beta = 0 */) { 1175 // here we can get away with purely write-only as we memcpy into it first 1176 const auto yptr = DeviceArrayWrite(dctx, yin); 1177 cupmStream_t stream; 1178 1179 PetscCall(GetHandlesFrom_(dctx, &stream)); 1180 PetscCall(PetscLogGpuTimeBegin()); 1181 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream)); 1182 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1)); 1183 } else { 1184 const auto yptr = DeviceArrayReadWrite(dctx, yin); 1185 1186 PetscCall(PetscLogGpuTimeBegin()); 1187 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1)); 1188 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 1189 } 1190 } 1191 PetscCall(PetscLogGpuTimeEnd()); 1192 PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n)); 1193 PetscCall(PetscDeviceContextSynchronize(dctx)); 1194 } else { 1195 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 1196 } 1197 PetscFunctionReturn(0); 1198 } 1199 1200 // v->ops->axpbypcz 1201 template <device::cupm::DeviceType T> 1202 inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept 1203 { 1204 PetscFunctionBegin; 1205 if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma)); 1206 PetscCall(axpy(zin, alpha, xin)); 1207 PetscCall(axpy(zin, beta, yin)); 1208 PetscFunctionReturn(0); 1209 } 1210 1211 // v->ops->norm 1212 template <device::cupm::DeviceType T> 1213 inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept 1214 { 1215 PetscDeviceContext dctx; 1216 cupmBlasHandle_t cupmBlasHandle; 1217 1218 PetscFunctionBegin; 1219 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1220 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1221 const auto xptr = DeviceArrayRead(dctx, xin); 1222 PetscInt flopCount = 0; 1223 1224 PetscCall(PetscLogGpuTimeBegin()); 1225 switch (type) { 1226 case NORM_1_AND_2: 1227 case NORM_1: 1228 PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z))); 1229 flopCount = std::max(n - 1, 0); 1230 if (type == NORM_1) break; 1231 ++z; // fall-through 1232 #if PETSC_CPP_VERSION >= 17 1233 [[fallthrough]]; 1234 #endif 1235 case NORM_2: 1236 case NORM_FROBENIUS: 1237 PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z))); 1238 flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2 1239 break; 1240 case NORM_INFINITY: { 1241 cupmBlasInt_t max_loc = 0; 1242 PetscScalar xv = 0.; 1243 cupmStream_t stream; 1244 1245 PetscCall(GetHandlesFrom_(dctx, &stream)); 1246 PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc)); 1247 PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream)); 1248 *z = PetscAbsScalar(xv); 1249 // REVIEW ME: flopCount = ??? 1250 } break; 1251 } 1252 PetscCall(PetscLogGpuTimeEnd()); 1253 PetscCall(PetscLogGpuFlops(flopCount)); 1254 } else { 1255 z[0] = 0.0; 1256 z[type == NORM_1_AND_2] = 0.0; 1257 } 1258 PetscFunctionReturn(0); 1259 } 1260 1261 namespace detail 1262 { 1263 1264 struct dotnorm2_mult { 1265 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept 1266 { 1267 const auto conjt = PetscConj(t); 1268 1269 return {s * conjt, t * conjt}; 1270 } 1271 }; 1272 1273 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I 1274 // would do it myself but now I am worried that they do so on purpose... 1275 struct dotnorm2_tuple_plus { 1276 using value_type = thrust::tuple<PetscScalar, PetscScalar>; 1277 1278 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; } 1279 }; 1280 1281 } // namespace detail 1282 1283 // v->ops->dotnorm2 1284 template <device::cupm::DeviceType T> 1285 inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept 1286 { 1287 PetscDeviceContext dctx; 1288 cupmStream_t stream; 1289 1290 PetscFunctionBegin; 1291 PetscCall(GetHandles_(&dctx, &stream)); 1292 { 1293 PetscScalar dpt = 0.0, nmt = 0.0; 1294 const auto sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data()); 1295 1296 // clang-format off 1297 PetscCallThrust( 1298 thrust::tie(*dp, *nm) = THRUST_CALL( 1299 thrust::inner_product, 1300 stream, 1301 sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()), 1302 thrust::make_tuple(dpt, nmt), 1303 detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{} 1304 ); 1305 ); 1306 // clang-format on 1307 } 1308 PetscFunctionReturn(0); 1309 } 1310 1311 namespace detail 1312 { 1313 1314 struct conjugate { 1315 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); } 1316 }; 1317 1318 } // namespace detail 1319 1320 // v->ops->conjugate 1321 template <device::cupm::DeviceType T> 1322 inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept 1323 { 1324 PetscFunctionBegin; 1325 if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin)); 1326 PetscFunctionReturn(0); 1327 } 1328 1329 namespace detail 1330 { 1331 1332 struct real_part { 1333 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; } 1334 1335 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); } 1336 }; 1337 1338 // deriving from Operator allows us to "store" an instance of the operator in the class but 1339 // also take advantage of empty base class optimization if the operator is stateless 1340 template <typename Operator> 1341 class tuple_compare : Operator { 1342 public: 1343 using tuple_type = thrust::tuple<PetscReal, PetscInt>; 1344 using operator_type = Operator; 1345 1346 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept 1347 { 1348 if (op_()(y.get<0>(), x.get<0>())) { 1349 // if y is strictly greater/less than x, return y 1350 return y; 1351 } else if (y.get<0>() == x.get<0>()) { 1352 // if equal, prefer lower index 1353 return y.get<1>() < x.get<1>() ? y : x; 1354 } 1355 // otherwise return x 1356 return x; 1357 } 1358 1359 private: 1360 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; } 1361 }; 1362 1363 } // namespace detail 1364 1365 template <device::cupm::DeviceType T> 1366 template <typename TupleFuncT, typename UnaryFuncT> 1367 inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept 1368 { 1369 PetscFunctionBegin; 1370 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 1371 if (p) *p = -1; 1372 if (const auto n = v->map->n) { 1373 PetscDeviceContext dctx; 1374 cupmStream_t stream; 1375 1376 PetscCall(GetHandles_(&dctx, &stream)); 1377 // needed to: 1378 // 1. switch between transform_reduce and reduce 1379 // 2. strip the real_part functor from the arguments 1380 #if PetscDefined(USE_COMPLEX) 1381 #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__) 1382 #else 1383 #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__) 1384 #endif 1385 { 1386 const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1387 1388 if (p) { 1389 // clang-format off 1390 const auto zip = thrust::make_zip_iterator( 1391 thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0})) 1392 ); 1393 // clang-format on 1394 // need to use preprocessor conditionals since otherwise thrust complains about not being 1395 // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex 1396 // builds... 1397 // clang-format off 1398 PetscCallThrust( 1399 thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE( 1400 stream, zip, zip + n, detail::real_part{}, 1401 thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr) 1402 ); 1403 ); 1404 // clang-format on 1405 } else { 1406 // clang-format off 1407 PetscCallThrust( 1408 *m = THRUST_MINMAX_REDUCE( 1409 stream, vptr, vptr + n, detail::real_part{}, 1410 *m, std::forward<UnaryFuncT>(unary_ftr) 1411 ); 1412 ); 1413 // clang-format on 1414 } 1415 } 1416 #undef THRUST_MINMAX_REDUCE 1417 } 1418 // REVIEW ME: flops? 1419 PetscFunctionReturn(0); 1420 } 1421 1422 // v->ops->max 1423 template <device::cupm::DeviceType T> 1424 inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept 1425 { 1426 using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>; 1427 using unary_functor = thrust::maximum<PetscReal>; 1428 1429 PetscFunctionBegin; 1430 *m = PETSC_MIN_REAL; 1431 // use {} constructor syntax otherwise most vexing parse 1432 PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m)); 1433 PetscFunctionReturn(0); 1434 } 1435 1436 // v->ops->min 1437 template <device::cupm::DeviceType T> 1438 inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept 1439 { 1440 using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>; 1441 using unary_functor = thrust::minimum<PetscReal>; 1442 1443 PetscFunctionBegin; 1444 *m = PETSC_MAX_REAL; 1445 // use {} constructor syntax otherwise most vexing parse 1446 PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m)); 1447 PetscFunctionReturn(0); 1448 } 1449 1450 // v->ops->sum 1451 template <device::cupm::DeviceType T> 1452 inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept 1453 { 1454 PetscFunctionBegin; 1455 if (const auto n = v->map->n) { 1456 PetscDeviceContext dctx; 1457 cupmStream_t stream; 1458 1459 PetscCall(GetHandles_(&dctx, &stream)); 1460 const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1461 // REVIEW ME: why not cupmBlasXasum()? 1462 PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0});); 1463 // REVIEW ME: must be at least n additions 1464 PetscCall(PetscLogGpuFlops(n)); 1465 } else { 1466 *sum = 0.0; 1467 } 1468 PetscFunctionReturn(0); 1469 } 1470 1471 namespace detail 1472 { 1473 1474 template <typename T> 1475 class plus_equals { 1476 public: 1477 using value_type = T; 1478 1479 PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_(std::move(v)) { } 1480 1481 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &val) const noexcept { return val + v_; } 1482 1483 private: 1484 value_type v_; 1485 }; 1486 1487 } // namespace detail 1488 1489 template <device::cupm::DeviceType T> 1490 inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept 1491 { 1492 PetscFunctionBegin; 1493 PetscCall(pointwiseunary_(detail::plus_equals<PetscScalar>{shift}, v)); 1494 PetscFunctionReturn(0); 1495 } 1496 1497 template <device::cupm::DeviceType T> 1498 inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept 1499 { 1500 PetscFunctionBegin; 1501 if (const auto n = v->map->n) { 1502 PetscBool iscurand; 1503 PetscDeviceContext dctx; 1504 1505 PetscCall(GetHandles_(&dctx)); 1506 PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand)); 1507 if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v))); 1508 else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v))); 1509 } else { 1510 PetscCall(MaybeIncrementEmptyLocalVec(v)); 1511 } 1512 // REVIEW ME: flops???? 1513 // REVIEW ME: Timing??? 1514 PetscFunctionReturn(0); 1515 } 1516 1517 // v->ops->setpreallocation 1518 template <device::cupm::DeviceType T> 1519 inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept 1520 { 1521 PetscDeviceContext dctx; 1522 1523 PetscFunctionBegin; 1524 PetscCall(GetHandles_(&dctx)); 1525 PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i)); 1526 PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx)); 1527 PetscFunctionReturn(0); 1528 } 1529 1530 namespace kernels 1531 { 1532 1533 template <typename F> 1534 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 1535 { 1536 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 1537 const auto end = jmap[i + 1]; 1538 const auto idx = xvindex(i); 1539 PetscScalar sum = 0.0; 1540 1541 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 1542 1543 if (imode == INSERT_VALUES) { 1544 xv[idx] = sum; 1545 } else { 1546 xv[idx] += sum; 1547 } 1548 }); 1549 return; 1550 } 1551 1552 PETSC_KERNEL_DECL static void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 1553 { 1554 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 1555 return; 1556 } 1557 1558 } // namespace kernels 1559 1560 // v->ops->setvaluescoo 1561 template <device::cupm::DeviceType T> 1562 inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept 1563 { 1564 auto vv = const_cast<PetscScalar *>(v); 1565 PetscMemType memtype; 1566 PetscDeviceContext dctx; 1567 cupmStream_t stream; 1568 1569 PetscFunctionBegin; 1570 PetscCall(GetHandles_(&dctx, &stream)); 1571 PetscCall(PetscGetMemType(v, &memtype)); 1572 if (PetscMemTypeHost(memtype)) { 1573 const auto size = VecIMPLCast(x)->coo_n; 1574 1575 // If user gave v[] in host, we might need to copy it to device if any 1576 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv)); 1577 PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream)); 1578 } 1579 1580 if (const auto n = x->map->n) { 1581 const auto vcu = VecCUPMCast(x); 1582 1583 PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data())); 1584 } else { 1585 PetscCall(MaybeIncrementEmptyLocalVec(x)); 1586 } 1587 1588 if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv)); 1589 PetscCall(PetscDeviceContextSynchronize(dctx)); 1590 PetscFunctionReturn(0); 1591 } 1592 1593 // ========================================================================================== 1594 // VecSeq_CUPM - Implementations 1595 // ========================================================================================== 1596 1597 namespace 1598 { 1599 1600 template <typename T> 1601 PETSC_NODISCARD inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept 1602 { 1603 PetscFunctionBegin; 1604 PetscValidPointer(v, 4); 1605 PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE)); 1606 PetscFunctionReturn(0); 1607 } 1608 1609 template <typename T> 1610 PETSC_NODISCARD inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 1611 { 1612 PetscFunctionBegin; 1613 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5); 1614 PetscValidPointer(v, 7); 1615 PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v)); 1616 PetscFunctionReturn(0); 1617 } 1618 1619 template <PetscMemoryAccessMode mode, typename T> 1620 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1621 { 1622 PetscFunctionBegin; 1623 PetscValidHeaderSpecific(v, VEC_CLASSID, 2); 1624 PetscValidPointer(a, 3); 1625 PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a)); 1626 PetscFunctionReturn(0); 1627 } 1628 1629 template <PetscMemoryAccessMode mode, typename T> 1630 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1631 { 1632 PetscFunctionBegin; 1633 PetscValidHeaderSpecific(v, VEC_CLASSID, 2); 1634 PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a)); 1635 PetscFunctionReturn(0); 1636 } 1637 1638 template <typename T> 1639 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1640 { 1641 PetscFunctionBegin; 1642 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a)); 1643 PetscFunctionReturn(0); 1644 } 1645 1646 template <typename T> 1647 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1648 { 1649 PetscFunctionBegin; 1650 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a)); 1651 PetscFunctionReturn(0); 1652 } 1653 1654 template <typename T> 1655 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept 1656 { 1657 PetscFunctionBegin; 1658 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a))); 1659 PetscFunctionReturn(0); 1660 } 1661 1662 template <typename T> 1663 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept 1664 { 1665 PetscFunctionBegin; 1666 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a))); 1667 PetscFunctionReturn(0); 1668 } 1669 1670 template <typename T> 1671 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1672 { 1673 PetscFunctionBegin; 1674 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a)); 1675 PetscFunctionReturn(0); 1676 } 1677 1678 template <typename T> 1679 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1680 { 1681 PetscFunctionBegin; 1682 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a)); 1683 PetscFunctionReturn(0); 1684 } 1685 1686 template <typename T> 1687 PETSC_NODISCARD inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept 1688 { 1689 PetscFunctionBegin; 1690 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1691 PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1692 PetscFunctionReturn(0); 1693 } 1694 1695 template <typename T> 1696 PETSC_NODISCARD inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept 1697 { 1698 PetscFunctionBegin; 1699 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1700 PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1701 PetscFunctionReturn(0); 1702 } 1703 1704 template <typename T> 1705 PETSC_NODISCARD inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept 1706 { 1707 PetscFunctionBegin; 1708 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1709 PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin)); 1710 PetscFunctionReturn(0); 1711 } 1712 1713 } // anonymous namespace 1714 1715 } // namespace impl 1716 1717 } // namespace cupm 1718 1719 } // namespace vec 1720 1721 } // namespace Petsc 1722 1723 #endif // __cplusplus 1724 1725 #endif // PETSCVECSEQCUPM_HPP 1726