1 #ifndef PETSCVECSEQCUPM_HPP 2 #define PETSCVECSEQCUPM_HPP 3 4 #include <petsc/private/veccupmimpl.h> 5 #include <petsc/private/randomimpl.h> // for _p_PetscRandom 6 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp" 7 #include "../src/sys/objects/device/impls/cupm/cupmallocator.hpp" 8 #include "../src/sys/objects/device/impls/cupm/kernels.hpp" 9 10 #if defined(__cplusplus) 11 #include <thrust/transform_reduce.h> 12 #include <thrust/reduce.h> 13 #include <thrust/functional.h> 14 #include <thrust/iterator/counting_iterator.h> 15 #include <thrust/inner_product.h> 16 17 namespace Petsc 18 { 19 20 namespace vec 21 { 22 23 namespace cupm 24 { 25 26 namespace impl 27 { 28 29 // ========================================================================================== 30 // VecSeq_CUPM 31 // ========================================================================================== 32 33 template <device::cupm::DeviceType T> 34 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 35 public: 36 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 37 38 private: 39 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 40 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 41 42 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 43 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 44 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 45 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 46 47 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 48 49 // common core for min and max 50 template <typename TupleFuncT, typename UnaryFuncT> 51 static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 52 // common core for pointwise binary and pointwise unary thrust functions 53 template <typename BinaryFuncT> 54 static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept; 55 
template <typename UnaryFuncT> 56 static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept; 57 // mdot dispatchers 58 static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 59 static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 60 template <std::size_t... Idx> 61 static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 62 template <int> 63 static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 64 template <std::size_t... Idx> 65 static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 66 template <int> 67 static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 68 // common core for the various create routines 69 static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 70 71 public: 72 // callable directly via a bespoke function 73 static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 74 static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 75 76 // callable indirectly via function pointers 77 static PetscErrorCode duplicate(Vec, Vec *) noexcept; 78 static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept; 79 static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept; 80 static PetscErrorCode pointwisedivide(Vec, 
Vec, Vec) noexcept; 81 static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept; 82 static PetscErrorCode reciprocal(Vec) noexcept; 83 static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept; 84 static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 85 static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept; 86 static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 87 static PetscErrorCode set(Vec, PetscScalar) noexcept; 88 static PetscErrorCode scale(Vec, PetscScalar) noexcept; 89 static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept; 90 static PetscErrorCode copy(Vec, Vec) noexcept; 91 static PetscErrorCode swap(Vec, Vec) noexcept; 92 static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept; 93 static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 94 static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept; 95 static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 96 static PetscErrorCode destroy(Vec) noexcept; 97 static PetscErrorCode conjugate(Vec) noexcept; 98 template <PetscMemoryAccessMode> 99 static PetscErrorCode getlocalvector(Vec, Vec) noexcept; 100 template <PetscMemoryAccessMode> 101 static PetscErrorCode restorelocalvector(Vec, Vec) noexcept; 102 static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept; 103 static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept; 104 static PetscErrorCode sum(Vec, PetscScalar *) noexcept; 105 static PetscErrorCode shift(Vec, PetscScalar) noexcept; 106 static PetscErrorCode setrandom(Vec, PetscRandom) noexcept; 107 static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept; 108 static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept; 109 static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept; 110 }; 111 112 // 
========================================================================================== 113 // VecSeq_CUPM - Private API 114 // ========================================================================================== 115 116 template <device::cupm::DeviceType T> 117 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept 118 { 119 return static_cast<Vec_Seq *>(v->data); 120 } 121 122 template <device::cupm::DeviceType T> 123 inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept 124 { 125 return VECSEQCUPM(); 126 } 127 128 template <device::cupm::DeviceType T> 129 inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept 130 { 131 return VecDestroy_Seq(v); 132 } 133 134 template <device::cupm::DeviceType T> 135 inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept 136 { 137 return VecResetArray_Seq(v); 138 } 139 140 template <device::cupm::DeviceType T> 141 inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept 142 { 143 return VecPlaceArray_Seq(v, a); 144 } 145 146 template <device::cupm::DeviceType T> 147 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept 148 { 149 PetscMPIInt size; 150 151 PetscFunctionBegin; 152 if (alloc_missing) *alloc_missing = PETSC_FALSE; 153 PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size)); 154 PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size); 155 PetscCall(VecCreate_Seq_Private(v, host_array)); 156 PetscFunctionReturn(PETSC_SUCCESS); 157 } 158 159 // for functions with an early return based one vec size we still need to artificially bump the 160 // object state. This is to prevent the following: 161 // 162 // 0. Suppose you have a Vec { 163 // rank 0: [0], 164 // rank 1: [<empty>] 165 // } 166 // 1. 
both ranks have Vec with PetscObjectState = 0, stashed norm of 0 167 // 2. Vec enters e.g. VecSet(10) 168 // 3. rank 1 has local size 0 and bails immediately 169 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite() 170 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1 171 // 6. Vec enters VecNorm(), and calls VecNormAvailable() 172 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0 173 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function 174 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early 175 template <device::cupm::DeviceType T> 176 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept 177 { 178 PetscFunctionBegin; 179 if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v))); 180 PetscFunctionReturn(PETSC_SUCCESS); 181 } 182 183 template <device::cupm::DeviceType T> 184 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept 185 { 186 PetscFunctionBegin; 187 PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array)); 188 PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx)); 189 PetscFunctionReturn(PETSC_SUCCESS); 190 } 191 192 template <device::cupm::DeviceType T> 193 template <typename BinaryFuncT> 194 inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept 195 { 196 PetscFunctionBegin; 197 if (const auto n = zout->map->n) { 198 PetscDeviceContext dctx; 199 200 PetscCall(GetHandles_(&dctx)); 201 PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<BinaryFuncT>(binary), n, DeviceArrayRead(dctx, xin).data(), DeviceArrayRead(dctx, yin).data(), DeviceArrayWrite(dctx, zout).data())); 202 PetscCall(PetscDeviceContextSynchronize(dctx)); 203 } else 
{ 204 PetscCall(MaybeIncrementEmptyLocalVec(zout)); 205 } 206 PetscFunctionReturn(PETSC_SUCCESS); 207 } 208 209 template <device::cupm::DeviceType T> 210 template <typename UnaryFuncT> 211 inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept 212 { 213 const auto inplace = !yin || (xinout == yin); 214 215 PetscFunctionBegin; 216 if (const auto n = xinout->map->n) { 217 PetscDeviceContext dctx; 218 219 PetscCall(GetHandles_(&dctx)); 220 if (inplace) { 221 PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayReadWrite(dctx, xinout).data())); 222 } else { 223 PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data())); 224 } 225 PetscCall(PetscDeviceContextSynchronize(dctx)); 226 } else { 227 if (inplace) { 228 PetscCall(MaybeIncrementEmptyLocalVec(xinout)); 229 } else { 230 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 231 } 232 } 233 PetscFunctionReturn(PETSC_SUCCESS); 234 } 235 236 // ========================================================================================== 237 // VecSeq_CUPM - Public API - Constructors 238 // ========================================================================================== 239 240 // VecCreateSeqCUPM() 241 template <device::cupm::DeviceType T> 242 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept 243 { 244 PetscFunctionBegin; 245 PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type)); 246 PetscFunctionReturn(PETSC_SUCCESS); 247 } 248 249 // VecCreateSeqCUPMWithArrays() 250 template <device::cupm::DeviceType T> 251 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept 252 { 253 
PetscDeviceContext dctx; 254 255 PetscFunctionBegin; 256 PetscCall(GetHandles_(&dctx)); 257 // do NOT call VecSetType(), otherwise ops->create() -> create() -> 258 // createseqcupm_() is called! 259 PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE)); 260 PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array))); 261 PetscFunctionReturn(PETSC_SUCCESS); 262 } 263 264 // v->ops->duplicate 265 template <device::cupm::DeviceType T> 266 inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept 267 { 268 PetscDeviceContext dctx; 269 270 PetscFunctionBegin; 271 PetscCall(GetHandles_(&dctx)); 272 PetscCall(Duplicate_CUPMBase(v, y, dctx)); 273 PetscFunctionReturn(PETSC_SUCCESS); 274 } 275 276 // ========================================================================================== 277 // VecSeq_CUPM - Public API - Utility 278 // ========================================================================================== 279 280 // v->ops->bindtocpu 281 template <device::cupm::DeviceType T> 282 inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept 283 { 284 PetscDeviceContext dctx; 285 286 PetscFunctionBegin; 287 PetscCall(GetHandles_(&dctx)); 288 PetscCall(BindToCPU_CUPMBase(v, usehost, dctx)); 289 290 // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess 291 VecSetOp_CUPM(dot, VecDot_Seq, dot); 292 VecSetOp_CUPM(norm, VecNorm_Seq, norm); 293 VecSetOp_CUPM(tdot, VecTDot_Seq, tdot); 294 VecSetOp_CUPM(mdot, VecMDot_Seq, mdot); 295 VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>); 296 VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>); 297 v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq; 298 VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate); 299 VecSetOp_CUPM(max, VecMax_Seq, max); 300 VecSetOp_CUPM(min, VecMin_Seq, min); 301 
VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo); 302 VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo); 303 PetscFunctionReturn(PETSC_SUCCESS); 304 } 305 306 // ========================================================================================== 307 // VecSeq_CUPM - Public API - Mutators 308 // ========================================================================================== 309 310 // v->ops->getlocalvector or v->ops->getlocalvectorread 311 template <device::cupm::DeviceType T> 312 template <PetscMemoryAccessMode access> 313 inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept 314 { 315 PetscBool wisseqcupm; 316 317 PetscFunctionBegin; 318 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 319 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 320 if (wisseqcupm) { 321 if (const auto wseq = VecIMPLCast(w)) { 322 if (auto &alloced = wseq->array_allocated) { 323 const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE)); 324 325 PetscCall(PetscFree(alloced)); 326 } 327 wseq->array = nullptr; 328 wseq->unplacedarray = nullptr; 329 } 330 if (const auto wcu = VecCUPMCast(w)) { 331 if (auto &device_array = wcu->array_d) { 332 cupmStream_t stream; 333 334 PetscCall(GetHandles_(&stream)); 335 PetscCallCUPM(cupmFreeAsync(device_array, stream)); 336 } 337 PetscCall(PetscFree(w->spptr /* wcu */)); 338 } 339 } 340 if (v->petscnative && wisseqcupm) { 341 PetscCall(PetscFree(w->data)); 342 w->data = v->data; 343 w->offloadmask = v->offloadmask; 344 w->pinned_memory = v->pinned_memory; 345 w->spptr = v->spptr; 346 PetscCall(PetscObjectStateIncrease(PetscObjectCast(w))); 347 } else { 348 const auto array = &VecIMPLCast(w)->array; 349 350 if (access == PETSC_MEMORY_ACCESS_READ) { 351 PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array))); 352 } else { 353 PetscCall(VecGetArray(v, array)); 354 } 355 w->offloadmask = 
PETSC_OFFLOAD_CPU; 356 if (wisseqcupm) { 357 PetscDeviceContext dctx; 358 359 PetscCall(GetHandles_(&dctx)); 360 PetscCall(DeviceAllocateCheck_(dctx, w)); 361 } 362 } 363 PetscFunctionReturn(PETSC_SUCCESS); 364 } 365 366 // v->ops->restorelocalvector or v->ops->restorelocalvectorread 367 template <device::cupm::DeviceType T> 368 template <PetscMemoryAccessMode access> 369 inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept 370 { 371 PetscBool wisseqcupm; 372 373 PetscFunctionBegin; 374 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 375 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 376 if (v->petscnative && wisseqcupm) { 377 // the assignments to nullptr are __critical__, as w may persist after this call returns 378 // and shouldn't share data with v! 379 v->pinned_memory = w->pinned_memory; 380 v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED); 381 v->data = util::exchange(w->data, nullptr); 382 v->spptr = util::exchange(w->spptr, nullptr); 383 } else { 384 const auto array = &VecIMPLCast(w)->array; 385 386 if (access == PETSC_MEMORY_ACCESS_READ) { 387 PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array))); 388 } else { 389 PetscCall(VecRestoreArray(v, array)); 390 } 391 if (w->spptr && wisseqcupm) { 392 cupmStream_t stream; 393 394 PetscCall(GetHandles_(&stream)); 395 PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream)); 396 PetscCall(PetscFree(w->spptr)); 397 } 398 } 399 PetscFunctionReturn(PETSC_SUCCESS); 400 } 401 402 // ========================================================================================== 403 // VecSeq_CUPM - Public API - Compute Methods 404 // ========================================================================================== 405 406 // v->ops->aypx 407 template <device::cupm::DeviceType T> 408 inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept 409 { 410 const auto n = 
static_cast<cupmBlasInt_t>(yin->map->n); 411 const auto sync = n != 0; 412 PetscDeviceContext dctx; 413 414 PetscFunctionBegin; 415 PetscCall(GetHandles_(&dctx)); 416 if (alpha == PetscScalar(0.0)) { 417 cupmStream_t stream; 418 419 PetscCall(GetHandlesFrom_(dctx, &stream)); 420 PetscCall(PetscLogGpuTimeBegin()); 421 PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream)); 422 PetscCall(PetscLogGpuTimeEnd()); 423 } else if (n) { 424 const auto alphaIsOne = alpha == PetscScalar(1.0); 425 const auto calpha = cupmScalarPtrCast(&alpha); 426 cupmBlasHandle_t cupmBlasHandle; 427 428 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle)); 429 { 430 const auto yptr = DeviceArrayReadWrite(dctx, yin); 431 const auto xptr = DeviceArrayRead(dctx, xin); 432 433 PetscCall(PetscLogGpuTimeBegin()); 434 if (alphaIsOne) { 435 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 436 } else { 437 const auto one = cupmScalarCast(1.0); 438 439 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1)); 440 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 441 } 442 PetscCall(PetscLogGpuTimeEnd()); 443 } 444 PetscCall(PetscLogGpuFlops((alphaIsOne ? 
1 : 2) * n)); 445 } 446 if (sync) PetscCall(PetscDeviceContextSynchronize(dctx)); 447 PetscFunctionReturn(PETSC_SUCCESS); 448 } 449 450 // v->ops->axpy 451 template <device::cupm::DeviceType T> 452 inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept 453 { 454 PetscBool xiscupm; 455 456 PetscFunctionBegin; 457 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 458 if (xiscupm) { 459 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 460 PetscDeviceContext dctx; 461 cupmBlasHandle_t cupmBlasHandle; 462 463 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 464 PetscCall(PetscLogGpuTimeBegin()); 465 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 466 PetscCall(PetscLogGpuTimeEnd()); 467 PetscCall(PetscLogGpuFlops(2 * n)); 468 PetscCall(PetscDeviceContextSynchronize(dctx)); 469 } else { 470 PetscCall(VecAXPY_Seq(yin, alpha, xin)); 471 } 472 PetscFunctionReturn(PETSC_SUCCESS); 473 } 474 475 // v->ops->pointwisedivide 476 template <device::cupm::DeviceType T> 477 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept 478 { 479 PetscFunctionBegin; 480 if (xin->boundtocpu || yin->boundtocpu) { 481 PetscCall(VecPointwiseDivide_Seq(win, xin, yin)); 482 } else { 483 // note order of arguments! xin and yin are read, win is written! 484 PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win)); 485 } 486 PetscFunctionReturn(PETSC_SUCCESS); 487 } 488 489 // v->ops->pointwisemult 490 template <device::cupm::DeviceType T> 491 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept 492 { 493 PetscFunctionBegin; 494 if (xin->boundtocpu || yin->boundtocpu) { 495 PetscCall(VecPointwiseMult_Seq(win, xin, yin)); 496 } else { 497 // note order of arguments! xin and yin are read, win is written! 
498 PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win)); 499 } 500 PetscFunctionReturn(PETSC_SUCCESS); 501 } 502 503 namespace detail 504 { 505 506 struct reciprocal { 507 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept 508 { 509 // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex 510 // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap 511 // everything in PetscScalar... 512 return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s; 513 } 514 }; 515 516 } // namespace detail 517 518 // v->ops->reciprocal 519 template <device::cupm::DeviceType T> 520 inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept 521 { 522 PetscFunctionBegin; 523 PetscCall(pointwiseunary_(detail::reciprocal{}, xin)); 524 PetscFunctionReturn(PETSC_SUCCESS); 525 } 526 527 // v->ops->waxpy 528 template <device::cupm::DeviceType T> 529 inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept 530 { 531 PetscFunctionBegin; 532 if (alpha == PetscScalar(0.0)) { 533 PetscCall(copy(yin, win)); 534 } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) { 535 PetscDeviceContext dctx; 536 cupmBlasHandle_t cupmBlasHandle; 537 cupmStream_t stream; 538 539 PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream)); 540 { 541 const auto wptr = DeviceArrayWrite(dctx, win); 542 543 PetscCall(PetscLogGpuTimeBegin()); 544 PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true)); 545 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1)); 546 PetscCall(PetscLogGpuTimeEnd()); 547 } 548 PetscCall(PetscLogGpuFlops(2 * n)); 549 PetscCall(PetscDeviceContextSynchronize(dctx)); 550 } 551 PetscFunctionReturn(PETSC_SUCCESS); 552 } 553 554 namespace kernels 555 { 556 557 template 
<typename... Args> 558 PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr) 559 { 560 constexpr int N = sizeof...(Args); 561 const auto tx = threadIdx.x; 562 const PetscScalar *yptr_p[] = {yptr...}; 563 564 PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N]; 565 566 // load a to shared memory 567 if (tx < N) aptr_shmem[tx] = aptr[tx]; 568 __syncthreads(); 569 570 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 571 // these may look the same but give different results! 572 #if 0 573 PetscScalar sum = 0.0; 574 575 #pragma unroll 576 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 577 xptr[i] += sum; 578 #else 579 auto sum = xptr[i]; 580 581 #pragma unroll 582 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 583 xptr[i] = sum; 584 #endif 585 }); 586 return; 587 } 588 589 } // namespace kernels 590 591 namespace detail 592 { 593 594 // a helper-struct to gobble the size_t input, it is used with template parameter pack 595 // expansion such that 596 // typename repeat_type<MyType, IdxParamPack>... 597 // expands to 598 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times] 599 template <typename T, std::size_t> 600 struct repeat_type { 601 using type = T; 602 }; 603 604 } // namespace detail 605 606 template <device::cupm::DeviceType T> 607 template <std::size_t... Idx> 608 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept 609 { 610 PetscFunctionBegin; 611 // clang-format off 612 PetscCall( 613 PetscCUPMLaunchKernel1D( 614 size, 0, stream, 615 kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 616 size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()... 
617 ) 618 ); 619 // clang-format on 620 PetscFunctionReturn(PETSC_SUCCESS); 621 } 622 623 template <device::cupm::DeviceType T> 624 template <int N> 625 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept 626 { 627 PetscFunctionBegin; 628 PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{})); 629 yidx += N; 630 PetscFunctionReturn(PETSC_SUCCESS); 631 } 632 633 // v->ops->maxpy 634 template <device::cupm::DeviceType T> 635 inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept 636 { 637 const auto n = xin->map->n; 638 PetscDeviceContext dctx; 639 cupmStream_t stream; 640 641 PetscFunctionBegin; 642 PetscCall(GetHandles_(&dctx, &stream)); 643 { 644 const auto xptr = DeviceArrayReadWrite(dctx, xin); 645 PetscScalar *d_alpha = nullptr; 646 PetscInt yidx = 0; 647 648 // placement of early-return is deliberate, we would like to capture the 649 // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail 650 if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS); 651 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha)); 652 PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream)); 653 PetscCall(PetscLogGpuTimeBegin()); 654 do { 655 switch (nv - yidx) { 656 case 7: 657 PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 658 break; 659 case 6: 660 PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 661 break; 662 case 5: 663 PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 664 break; 665 case 4: 666 PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 667 break; 668 case 3: 669 
PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 670 break; 671 case 2: 672 PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 673 break; 674 case 1: 675 PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 676 break; 677 default: // 8 or more 678 PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 679 break; 680 } 681 } while (yidx < nv); 682 PetscCall(PetscLogGpuTimeEnd()); 683 PetscCall(PetscDeviceFree(dctx, d_alpha)); 684 } 685 PetscCall(PetscLogGpuFlops(nv * 2 * n)); 686 PetscCall(PetscDeviceContextSynchronize(dctx)); 687 PetscFunctionReturn(PETSC_SUCCESS); 688 } 689 690 template <device::cupm::DeviceType T> 691 inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept 692 { 693 PetscFunctionBegin; 694 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 695 PetscDeviceContext dctx; 696 cupmBlasHandle_t cupmBlasHandle; 697 698 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 699 // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the 700 // second 701 PetscCall(PetscLogGpuTimeBegin()); 702 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z))); 703 PetscCall(PetscLogGpuTimeEnd()); 704 PetscCall(PetscLogGpuFlops(2 * n - 1)); 705 } else { 706 *z = 0.0; 707 } 708 PetscFunctionReturn(PETSC_SUCCESS); 709 } 710 711 #define MDOT_WORKGROUP_NUM 128 712 #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM 713 714 namespace kernels 715 { 716 717 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept 718 { 719 const auto group_entries = (size - 1) / gridDim.x + 1; 720 // for very small vectors, a group should still do some work 721 return group_entries ? group_entries : 1; 722 } 723 724 template <typename... 
ConstPetscScalarPointer>
// Device kernel used by the batched multi-dot product. For every pointer in the pack `y`
// it accumulates y[j][i] * x[i] over this block's slice of the index range, reduces the
// per-thread partial sums through `shmem`, and has the first N threads of each block write
// one partial result per (vector, block) pair to results[bx + j * gridDim.x].
PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int N = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx = threadIdx.x;
  const auto bx = blockIdx.x;
  const auto bdx = blockDim.x;
  const auto gdx = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin = tx + bx * worksize;
  const auto end = min((bx + 1) * worksize, size);

#pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

#pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stash the per-thread partial sums; vector i occupies shmem[i * MDOT_WORKGROUP_SIZE + tx]
#pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
#pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

namespace
{

// Second-stage reduction kernel: for every index i assigned to the calling thread it adds up
// the entries results[i * MDOT_WORKGROUP_SIZE, (i + 1) * MDOT_WORKGROUP_SIZE) and afterwards
// stores that sum back into results[i].
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int local_i = 0;
  // NOTE(review): fixed capacity of 8 sums per thread; nothing here guards local_i against
  // exceeding it -- presumably the launch configuration guarantees this, confirm.
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  //  *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  //  *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  //  *-*-*
  //  | | | ...
  //  *-*-*
  //
  // NOTE(review): the loop below is bounded by MDOT_WORKGROUP_SIZE although the comment above
  // speaks of MDOT_WORKGROUP_NUM; this only agrees if both macros have the same value -- confirm.
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  // NOTE(review): sums are popped from the back of local_results while the indices are
  // presumably visited in the same order as above, so a thread that handled more than one
  // index would write them back in reversed order -- confirm each thread handles at most one.
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // namespace

} // namespace kernels

// Launches kernels::mdot_kernel on `stream` for sizeof...(Idx) vectors taken from yin[Idx].
// The kernel is instantiated with one `const PetscScalar *` parameter per index and receives
// the device read pointer of every yin[Idx] as its trailing arguments.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Convenience overload: dispatches a batch of N vectors starting at yin[yidx], pointing the
// kernel at the slice of `results` that begins at yidx * MDOT_WORKGROUP_NUM, and then advances
// yidx (passed by reference) by N.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// std::false_type overload of mdot_. Processes the nv vectors in batches of at most 8 using
// mdot_kernel (a trailing single vector is handed to cupmBlasXdot instead, which writes
// straight into z), optionally spreading the batches over forked sub-contexts, then reduces
// the per-block partial results with sum_kernel and copies the first nvbatch values to z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto nwork = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto xptr = DeviceArrayRead(dctx, xin);
    PetscInt yidx = 0;
    auto subidx = 0;
    auto cur_stream = stream;
    auto cur_ctx = dctx;
    PetscDeviceContext *sub = nullptr;
    PetscStreamType stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // with sub-contexts available, hand each batch to the next one in round-robin order;
      // otherwise everything runs on the parent context/stream
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      // the number of vectors still to process selects the batch size; every dispatch
      // advances yidx by the size of the batch it handled
      switch (nv - yidx) {
      case 7:
        PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a single leftover vector: the BLAS dot writes its result directly to z[yidx]
        // rather than going through d_results
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // collapse the per-block partial sums so that d_results[i] holds the result for vector i
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE

// std::true_type overload of mdot_. Issues one cupmBlasXdot per vector in yin, cycling
// through up to 8 forked sub-contexts, with each result written to a device scratch buffer
// (the BLAS handle is temporarily switched to device pointer mode for that purpose). After
// the sub-contexts are joined the nv results are copied from the scratch buffer to z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto n_sub = PetscMin(nv, 8);
  const auto n = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto xptr = DeviceArrayRead(dctx, xin);
  PetscScalar *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto sub = subctx[i % n_sub];
    cupmBlasHandle_t handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // the result pointer (d_z + i) is device memory, so make sure the handle is in device
    // pointer mode for the call and put the previous mode back afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot
// Entry point for the multi-dot product. A single vector is forwarded to dot(); otherwise,
// for a non-empty local vector, the work goes to the mdot_ overload selected by
// PetscDefined(USE_COMPLEX) and the device context is synchronized before returning. With an
// empty local vector z is simply zeroed.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    // NOTE(review): nv is 1 in this branch, so the remark above presumably refers to an empty
    // local vector (n = 0) -- confirm
    PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->set
// Sets every local entry of xin to alpha: a memset when alpha is exactly zero, ThrustSet
// otherwise. The context is synchronized only when the vector is non-empty.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
{
  const auto n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto xptr = DeviceArrayWrite(dctx, xin);

    if (alpha == PetscScalar(0.0)) {
      PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
    } else {
      PetscCall(device::cupm::impl::ThrustSet<T>(stream, n, xptr.data(), &alpha));
    }
    if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->scale
// Scales xin by alpha. alpha == 1 returns immediately, alpha == 0 is delegated to set(),
// any other value goes through cupmBlasXscal when the local vector is non-empty.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(set(xin, alpha));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->tdot
// Computes the dot product of xin and yin via cupmBlasXdotu (presumably the unconjugated
// variant -- confirm against the BLAS wrapper); an empty local vector yields *z = 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->copy
// Copies xin into yout. The offload masks of both vectors decide which of the four memcpy
// kinds (device/host -> device/host) is used; copying a vector onto itself is a no-op.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      // destination is the device array of yout; the source is read from device or host
      // depending on the mode
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // destination is obtained through the generic VecGetArrayWrite() interface
      const auto xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->swap
// Exchanges the contents of xin and yin with cupmBlasXswap; swapping a vector with itself
// returns immediately.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpby
// Updates yin from alpha, beta and xin. The special values alpha == 0, beta == 1 and
// alpha == 1 are forwarded to scale(), axpy() and aypx() respectively. In the general case
// yin is first scaled by beta and then alpha * xin is added; for beta == 0 yin is instead
// overwritten with a copy of xin that is then scaled by alpha.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    PetscCall(axpy(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    PetscCall(aypx(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto betaIsZero = beta == PetscScalar(0.0);
    const auto aptr = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    // closes the timing region opened in whichever branch ran above
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpbypcz
// Composed from existing operations: zin is scaled by gamma (skipped when gamma == 1) and
// then updated by axpy() with (alpha, xin) followed by axpy() with (beta, yin).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
  PetscCall(axpy(zin, alpha, xin));
  PetscCall(axpy(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->norm
// Computes the requested norm of xin into z. For NORM_1_AND_2 two values are written: the
// asum result to z[0] and the nrm2 result to z[1]. With an empty local vector the output(s)
// are set to zero.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr = DeviceArrayRead(dctx, xin);
    PetscInt flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar xv = 0.;
      cupmStream_t stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // max_loc is used as a 1-based position, hence the "- 1" when fetching the entry
      // NOTE(review): xv is read right after an *Async copy with no explicit synchronization
      // visible here -- presumably the copy has completed by then, confirm
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    z[0] = 0.0;
    // index is 1 only for NORM_1_AND_2, otherwise this re-zeroes z[0]
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Binary "multiply" step for dotnorm2: maps a pair (s, t) to the tuple (s * conj(t), t * conj(t))
struct dotnorm2_mult {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};

// it is positively __bananas__ that thrust does not define default operator+ for tuples... I
// would do it myself but now I am worried that they do so on purpose...
// Element-wise sum of two (PetscScalar, PetscScalar) tuples.
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};

} // namespace detail

// v->ops->dotnorm2
// Computes in one pass, via thrust::inner_product with the two functors from detail above,
// *dp = sum of s_i * conj(t_i) and *nm = sum of t_i * conj(t_i).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // initial values for the reduction
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Unary functor returning PetscConj(x)
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail

// v->ops->conjugate
// Applies detail::conjugate to every entry of xin; does nothing unless USE_COMPLEX is defined.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Extracts the real part, either of a plain scalar or of the value half of a
// (value, index) tuple (the index is passed through unchanged).
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};

// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
// Reduction functor over (value, index) tuples: picks the tuple whose value wins under
// Operator, breaking ties in favour of the smaller index.
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ? y : x;
    }
    // otherwise return x
    return x;
  }

private:
  // access the stored comparison operator (the private base sub-object)
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};

} // namespace detail

// Shared implementation of min() and max(). The caller pre-loads *m with the starting value
// of the reduction. When p is non-null the reduction runs over (value, index) pairs with
// tuple_ftr and also reports the location in *p (initialised to -1); otherwise only the
// value is reduced with unary_ftr. Nothing is computed for an empty local vector.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
#if PetscDefined(USE_COMPLEX)
  #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
  #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair every value with its index so the reduction can track the location
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
#undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->max
// Maximum of v (and optionally its location in *p); the reduction starts from PETSC_MIN_REAL.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
{
  using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
  using unary_functor = thrust::maximum<PetscReal>;

  PetscFunctionBegin;
  *m = PETSC_MIN_REAL;
  // use {} constructor syntax otherwise most vexing parse
  PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->min
// Minimum of v (and optionally its location in *p); the reduction starts from PETSC_MAX_REAL.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
{
  using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
  using unary_functor = thrust::minimum<PetscReal>;

  PetscFunctionBegin;
  *m = PETSC_MAX_REAL;
  // use {} constructor syntax otherwise most vexing parse
  PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->sum
// Adds up all local entries of v with thrust::reduce; an empty local vector gives 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t stream;

    PetscCall(GetHandles_(&dctx, &stream));
    const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
    // REVIEW ME: why not cupmBlasXasum()?
    PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
    // REVIEW ME: must be at least n additions
    PetscCall(PetscLogGpuFlops(n));
  } else {
    *sum = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Unary functor that returns its argument plus a constant fixed at construction time
template <typename T>
class plus_equals {
public:
  using value_type = T;

  PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_(std::move(v)) { }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &val) const noexcept { return val + v_; }

private:
  value_type v_;
};

} // namespace detail

// Adds `shift` to every entry of v by applying detail::plus_equals through pointwiseunary_.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(detail::plus_equals<PetscScalar>{shift}, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Fills v with values from `rand`. If rand is of type PETSCCURAND the values are generated
// directly into the device array, otherwise into the host array.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->setpreallocationcoo
// Runs the sequential host-side COO preallocation and then the CUPM base-class counterpart.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

// Device helper shared by the COO kernels. For every i in [0, n) it sums the entries
// vv[perm[k]] for k in [jmap[i], jmap[i + 1]) and either stores the sum in (INSERT_VALUES) or
// adds it to (any other mode) xv[xvindex(i)].
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto end = jmap[i + 1];
    const auto idx = xvindex(i);
    PetscScalar sum = 0.0;

    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}

namespace
{

// Kernel entry point for the sequential case: entry i of the COO data maps to xv[i]
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
#endif

} // namespace kernels

// v->ops->setvaluescoo
// Inserts or adds the COO values v[] into x using kernels::add_coo_values. If v[] lives in
// host memory it is first staged into a temporary device buffer, which is released again
// before the final synchronization.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto vv = const_cast<PetscScalar *>(v);
  PetscMemType memtype;
  PetscDeviceContext dctx;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry, so write-only access to x suffices in that case
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Implementations
// ==========================================================================================

namespace
{

// Validates the output pointer and forwards to createseqcupm() of the given implementation
// object (block size argument 0, last argument PETSC_TRUE).
template <typename T>
inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscValidPointer(v, 4);
  PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Validates the arguments (the host array only when n is non-zero and the array is given) and
// forwards to createseqcupmwithbotharrays() of the given implementation object.
template <typename T>
inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5);
  PetscValidPointer(v, 7);
  PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Common core of the GetArray wrappers below: validates v and a, then requests the device
// array of v with the access mode given as template parameter.
template <PetscMemoryAccessMode mode, typename T>
inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
  PetscValidPointer(a, 3);
  PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Common core of the RestoreArray wrappers below: validates v and hands the device array
// back with the access mode given as template parameter.
template <PetscMemoryAccessMode mode, typename T>
inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
  PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Read-write access to the device array of v
template <typename T>
inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Counterpart of VecCUPMGetArrayAsync()
template <typename T>
inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Read-only access to the device array of v; constness is cast away only to reuse the
// shared private helper
template <typename T>
inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Counterpart of VecCUPMGetArrayReadAsync()
template <typename T>
inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Write-only access to the device array of v
template <typename T>
inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(PETSC_SUCCESS);
1709 } 1710 1711 template <typename T> 1712 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept 1713 { 1714 PetscFunctionBegin; 1715 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a)); 1716 PetscFunctionReturn(PETSC_SUCCESS); 1717 } 1718 1719 template <typename T> 1720 inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept 1721 { 1722 PetscFunctionBegin; 1723 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1724 PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1725 PetscFunctionReturn(PETSC_SUCCESS); 1726 } 1727 1728 template <typename T> 1729 inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept 1730 { 1731 PetscFunctionBegin; 1732 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1733 PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1734 PetscFunctionReturn(PETSC_SUCCESS); 1735 } 1736 1737 template <typename T> 1738 inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept 1739 { 1740 PetscFunctionBegin; 1741 PetscValidHeaderSpecific(vin, VEC_CLASSID, 2); 1742 PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin)); 1743 PetscFunctionReturn(PETSC_SUCCESS); 1744 } 1745 1746 } // anonymous namespace 1747 1748 } // namespace impl 1749 1750 } // namespace cupm 1751 1752 } // namespace vec 1753 1754 } // namespace Petsc 1755 1756 #endif // __cplusplus 1757 1758 #endif // PETSCVECSEQCUPM_HPP 1759