#ifndef PETSCVECSEQCUPM_HPP
#define PETSCVECSEQCUPM_HPP

#include <petsc/private/veccupmimpl.h>

#if defined(__cplusplus)
#include <petsc/private/randomimpl.h> // for _p_PetscRandom

#include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence

#include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
#include "../src/sys/objects/device/impls/cupm/kernels.hpp"

// transform_reduce is only needed for complex-scalar builds
#if PetscDefined(USE_COMPLEX)
#include <thrust/transform_reduce.h>
#endif
#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/inner_product.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

// ==========================================================================================
// VecSeq_CUPM
// ==========================================================================================

// Sequential (single-rank) device vector implementation, templated on the device type T.
// The shared CUPM vector machinery lives in Vec_CUPMBase, which receives this class as its
// second template argument (CRTP) and is inherited privately (class default access).
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // hooks used by the CRTP base to reach the host-side (VecSeq) implementation
  PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bumps the object state of a vector whose local part is empty but whose global size is not
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // the index_sequence overloads launch one kernel for sizeof...(Idx) vectors; the <int N>
  // overloads forward to them for N vectors starting at the running index and advance it
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

// ==========================================================================================
// VecSeq_CUPM - Private API
// ==========================================================================================

// reinterpret the generic v->data pointer as the host-side sequential implementation struct
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

// the type name of this implementation
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

// destruction is delegated to the host sequential implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}

// reset-array is delegated to the host sequential implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

// place-array is delegated to the host sequential implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

// Create the host-side sequential part of v, optionally wrapping a user-provided host_array.
// alloc_missing (if non-null) is always set to PETSC_FALSE; the PetscInt argument is unused.
// Errors out if v lives on a communicator with more than one rank.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// For functions with an early return based on vec size we still need to artificially bump the
// object state. This is to prevent the following:
//
//  0. Suppose you have a Vec {
//       rank 0: [0],
//       rank 1: [<empty>]
//     }
//  1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
//  2. Vec enters e.g. VecSet(10)
//  3. rank 1 has local size 0 and bails immediately
//  4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
//  5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
//  6. Vec enters VecNorm(), and calls VecNormAvailable()
//  7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
//  8. rank 0 has object state = 1, not equal to stash, continues to impl function
//  9.
rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early 183 template <device::cupm::DeviceType T> 184 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept 185 { 186 PetscFunctionBegin; 187 if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v))); 188 PetscFunctionReturn(PETSC_SUCCESS); 189 } 190 191 template <device::cupm::DeviceType T> 192 inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept 193 { 194 PetscFunctionBegin; 195 PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array)); 196 PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx)); 197 PetscFunctionReturn(PETSC_SUCCESS); 198 } 199 200 template <device::cupm::DeviceType T> 201 template <typename BinaryFuncT> 202 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept 203 { 204 PetscFunctionBegin; 205 if (const auto n = zout->map->n) { 206 PetscDeviceContext dctx; 207 cupmStream_t stream; 208 209 PetscCall(GetHandles_(&dctx, &stream)); 210 // clang-format off 211 PetscCallThrust( 212 const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data()); 213 214 THRUST_CALL( 215 thrust::transform, 216 stream, 217 dxptr, dxptr + n, 218 thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()), 219 thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()), 220 std::forward<BinaryFuncT>(binary) 221 ) 222 ); 223 // clang-format on 224 PetscCall(PetscLogFlops(n)); 225 PetscCall(PetscDeviceContextSynchronize(dctx)); 226 } else { 227 PetscCall(MaybeIncrementEmptyLocalVec(zout)); 228 } 229 PetscFunctionReturn(PETSC_SUCCESS); 230 } 231 232 template <device::cupm::DeviceType T> 233 template <typename UnaryFuncT> 234 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) 
noexcept 235 { 236 const auto inplace = !yin || (xinout == yin); 237 238 PetscFunctionBegin; 239 if (const auto n = xinout->map->n) { 240 PetscDeviceContext dctx; 241 cupmStream_t stream; 242 const auto apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) { 243 PetscFunctionBegin; 244 // clang-format off 245 PetscCallThrust( 246 const auto xptr = thrust::device_pointer_cast(xinout); 247 248 THRUST_CALL( 249 thrust::transform, 250 stream, 251 xptr, xptr + n, 252 (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr, 253 std::forward<UnaryFuncT>(unary) 254 ) 255 ); 256 PetscFunctionReturn(PETSC_SUCCESS); 257 }; 258 259 PetscCall(GetHandles_(&dctx, &stream)); 260 if (inplace) { 261 PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data())); 262 } else { 263 PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data())); 264 } 265 PetscCall(PetscLogFlops(n)); 266 PetscCall(PetscDeviceContextSynchronize(dctx)); 267 } else { 268 if (inplace) { 269 PetscCall(MaybeIncrementEmptyLocalVec(xinout)); 270 } else { 271 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 272 } 273 } 274 PetscFunctionReturn(PETSC_SUCCESS); 275 } 276 277 // ========================================================================================== 278 // VecSeq_CUPM - Public API - Constructors 279 // ========================================================================================== 280 281 // VecCreateSeqCUPM() 282 template <device::cupm::DeviceType T> 283 inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept 284 { 285 PetscFunctionBegin; 286 PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type)); 287 PetscFunctionReturn(PETSC_SUCCESS); 288 } 289 290 // VecCreateSeqCUPMWithArrays() 291 template <device::cupm::DeviceType T> 292 inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], 
const PetscScalar device_array[], Vec *v) noexcept 293 { 294 PetscDeviceContext dctx; 295 296 PetscFunctionBegin; 297 PetscCall(GetHandles_(&dctx)); 298 // do NOT call VecSetType(), otherwise ops->create() -> create() -> 299 // CreateSeqCUPM_() is called! 300 PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE)); 301 PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array))); 302 PetscFunctionReturn(PETSC_SUCCESS); 303 } 304 305 // v->ops->duplicate 306 template <device::cupm::DeviceType T> 307 inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept 308 { 309 PetscDeviceContext dctx; 310 311 PetscFunctionBegin; 312 PetscCall(GetHandles_(&dctx)); 313 PetscCall(Duplicate_CUPMBase(v, y, dctx)); 314 PetscFunctionReturn(PETSC_SUCCESS); 315 } 316 317 // ========================================================================================== 318 // VecSeq_CUPM - Public API - Utility 319 // ========================================================================================== 320 321 // v->ops->bindtocpu 322 template <device::cupm::DeviceType T> 323 inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept 324 { 325 PetscDeviceContext dctx; 326 327 PetscFunctionBegin; 328 PetscCall(GetHandles_(&dctx)); 329 PetscCall(BindToCPU_CUPMBase(v, usehost, dctx)); 330 331 // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess 332 VecSetOp_CUPM(dot, VecDot_Seq, Dot); 333 VecSetOp_CUPM(norm, VecNorm_Seq, Norm); 334 VecSetOp_CUPM(tdot, VecTDot_Seq, TDot); 335 VecSetOp_CUPM(mdot, VecMDot_Seq, MDot); 336 VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>); 337 VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>); 338 v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq; 339 VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate); 340 VecSetOp_CUPM(max, VecMax_Seq, Max); 341 
VecSetOp_CUPM(min, VecMin_Seq, Min); 342 VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO); 343 VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO); 344 PetscFunctionReturn(PETSC_SUCCESS); 345 } 346 347 // ========================================================================================== 348 // VecSeq_CUPM - Public API - Mutators 349 // ========================================================================================== 350 351 // v->ops->getlocalvector or v->ops->getlocalvectorread 352 template <device::cupm::DeviceType T> 353 template <PetscMemoryAccessMode access> 354 inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept 355 { 356 PetscBool wisseqcupm; 357 358 PetscFunctionBegin; 359 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 360 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 361 if (wisseqcupm) { 362 if (const auto wseq = VecIMPLCast(w)) { 363 if (auto &alloced = wseq->array_allocated) { 364 const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE)); 365 366 PetscCall(PetscFree(alloced)); 367 } 368 wseq->array = nullptr; 369 wseq->unplacedarray = nullptr; 370 } 371 if (const auto wcu = VecCUPMCast(w)) { 372 if (auto &device_array = wcu->array_d) { 373 cupmStream_t stream; 374 375 PetscCall(GetHandles_(&stream)); 376 PetscCallCUPM(cupmFreeAsync(device_array, stream)); 377 } 378 PetscCall(PetscFree(w->spptr /* wcu */)); 379 } 380 } 381 if (v->petscnative && wisseqcupm) { 382 PetscCall(PetscFree(w->data)); 383 w->data = v->data; 384 w->offloadmask = v->offloadmask; 385 w->pinned_memory = v->pinned_memory; 386 w->spptr = v->spptr; 387 PetscCall(PetscObjectStateIncrease(PetscObjectCast(w))); 388 } else { 389 const auto array = &VecIMPLCast(w)->array; 390 391 if (access == PETSC_MEMORY_ACCESS_READ) { 392 PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array))); 393 } else { 394 PetscCall(VecGetArray(v, 
array)); 395 } 396 w->offloadmask = PETSC_OFFLOAD_CPU; 397 if (wisseqcupm) { 398 PetscDeviceContext dctx; 399 400 PetscCall(GetHandles_(&dctx)); 401 PetscCall(DeviceAllocateCheck_(dctx, w)); 402 } 403 } 404 PetscFunctionReturn(PETSC_SUCCESS); 405 } 406 407 // v->ops->restorelocalvector or v->ops->restorelocalvectorread 408 template <device::cupm::DeviceType T> 409 template <PetscMemoryAccessMode access> 410 inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept 411 { 412 PetscBool wisseqcupm; 413 414 PetscFunctionBegin; 415 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 416 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm)); 417 if (v->petscnative && wisseqcupm) { 418 // the assignments to nullptr are __critical__, as w may persist after this call returns 419 // and shouldn't share data with v! 420 v->pinned_memory = w->pinned_memory; 421 v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED); 422 v->data = util::exchange(w->data, nullptr); 423 v->spptr = util::exchange(w->spptr, nullptr); 424 } else { 425 const auto array = &VecIMPLCast(w)->array; 426 427 if (access == PETSC_MEMORY_ACCESS_READ) { 428 PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array))); 429 } else { 430 PetscCall(VecRestoreArray(v, array)); 431 } 432 if (w->spptr && wisseqcupm) { 433 cupmStream_t stream; 434 435 PetscCall(GetHandles_(&stream)); 436 PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream)); 437 PetscCall(PetscFree(w->spptr)); 438 } 439 } 440 PetscFunctionReturn(PETSC_SUCCESS); 441 } 442 443 // ========================================================================================== 444 // VecSeq_CUPM - Public API - Compute Methods 445 // ========================================================================================== 446 447 // v->ops->aypx 448 template <device::cupm::DeviceType T> 449 inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) 
noexcept 450 { 451 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 452 const auto sync = n != 0; 453 PetscDeviceContext dctx; 454 455 PetscFunctionBegin; 456 PetscCall(GetHandles_(&dctx)); 457 if (alpha == PetscScalar(0.0)) { 458 cupmStream_t stream; 459 460 PetscCall(GetHandlesFrom_(dctx, &stream)); 461 PetscCall(PetscLogGpuTimeBegin()); 462 PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream)); 463 PetscCall(PetscLogGpuTimeEnd()); 464 } else if (n) { 465 const auto alphaIsOne = alpha == PetscScalar(1.0); 466 const auto calpha = cupmScalarPtrCast(&alpha); 467 cupmBlasHandle_t cupmBlasHandle; 468 469 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle)); 470 { 471 const auto yptr = DeviceArrayReadWrite(dctx, yin); 472 const auto xptr = DeviceArrayRead(dctx, xin); 473 474 PetscCall(PetscLogGpuTimeBegin()); 475 if (alphaIsOne) { 476 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 477 } else { 478 const auto one = cupmScalarCast(1.0); 479 480 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1)); 481 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 482 } 483 PetscCall(PetscLogGpuTimeEnd()); 484 } 485 PetscCall(PetscLogGpuFlops((alphaIsOne ? 
1 : 2) * n)); 486 } 487 if (sync) PetscCall(PetscDeviceContextSynchronize(dctx)); 488 PetscFunctionReturn(PETSC_SUCCESS); 489 } 490 491 // v->ops->axpy 492 template <device::cupm::DeviceType T> 493 inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept 494 { 495 PetscBool xiscupm; 496 497 PetscFunctionBegin; 498 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 499 if (xiscupm) { 500 const auto n = static_cast<cupmBlasInt_t>(yin->map->n); 501 PetscDeviceContext dctx; 502 cupmBlasHandle_t cupmBlasHandle; 503 504 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 505 PetscCall(PetscLogGpuTimeBegin()); 506 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 507 PetscCall(PetscLogGpuTimeEnd()); 508 PetscCall(PetscLogGpuFlops(2 * n)); 509 PetscCall(PetscDeviceContextSynchronize(dctx)); 510 } else { 511 PetscCall(VecAXPY_Seq(yin, alpha, xin)); 512 } 513 PetscFunctionReturn(PETSC_SUCCESS); 514 } 515 516 // v->ops->pointwisedivide 517 template <device::cupm::DeviceType T> 518 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept 519 { 520 PetscFunctionBegin; 521 if (xin->boundtocpu || yin->boundtocpu) { 522 PetscCall(VecPointwiseDivide_Seq(win, xin, yin)); 523 } else { 524 // note order of arguments! xin and yin are read, win is written! 525 PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win)); 526 } 527 PetscFunctionReturn(PETSC_SUCCESS); 528 } 529 530 // v->ops->pointwisemult 531 template <device::cupm::DeviceType T> 532 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept 533 { 534 PetscFunctionBegin; 535 if (xin->boundtocpu || yin->boundtocpu) { 536 PetscCall(VecPointwiseMult_Seq(win, xin, yin)); 537 } else { 538 // note order of arguments! xin and yin are read, win is written! 
539 PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win)); 540 } 541 PetscFunctionReturn(PETSC_SUCCESS); 542 } 543 544 namespace detail 545 { 546 547 struct reciprocal { 548 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept 549 { 550 // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex 551 // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap 552 // everything in PetscScalar... 553 return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s; 554 } 555 }; 556 557 } // namespace detail 558 559 // v->ops->reciprocal 560 template <device::cupm::DeviceType T> 561 inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept 562 { 563 PetscFunctionBegin; 564 PetscCall(PointwiseUnary_(detail::reciprocal{}, xin)); 565 PetscFunctionReturn(PETSC_SUCCESS); 566 } 567 568 // v->ops->waxpy 569 template <device::cupm::DeviceType T> 570 inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept 571 { 572 PetscFunctionBegin; 573 if (alpha == PetscScalar(0.0)) { 574 PetscCall(Copy(yin, win)); 575 } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) { 576 PetscDeviceContext dctx; 577 cupmBlasHandle_t cupmBlasHandle; 578 cupmStream_t stream; 579 580 PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream)); 581 { 582 const auto wptr = DeviceArrayWrite(dctx, win); 583 584 PetscCall(PetscLogGpuTimeBegin()); 585 PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true)); 586 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1)); 587 PetscCall(PetscLogGpuTimeEnd()); 588 } 589 PetscCall(PetscLogGpuFlops(2 * n)); 590 PetscCall(PetscDeviceContextSynchronize(dctx)); 591 } 592 PetscFunctionReturn(PETSC_SUCCESS); 593 } 594 595 namespace kernels 596 { 597 598 template 
<typename... Args> 599 PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr) 600 { 601 constexpr int N = sizeof...(Args); 602 const auto tx = threadIdx.x; 603 const PetscScalar *yptr_p[] = {yptr...}; 604 605 PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N]; 606 607 // load a to shared memory 608 if (tx < N) aptr_shmem[tx] = aptr[tx]; 609 __syncthreads(); 610 611 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 612 // these may look the same but give different results! 613 #if 0 614 PetscScalar sum = 0.0; 615 616 #pragma unroll 617 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 618 xptr[i] += sum; 619 #else 620 auto sum = xptr[i]; 621 622 #pragma unroll 623 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i]; 624 xptr[i] = sum; 625 #endif 626 }); 627 return; 628 } 629 630 } // namespace kernels 631 632 namespace detail 633 { 634 635 // a helper-struct to gobble the size_t input, it is used with template parameter pack 636 // expansion such that 637 // typename repeat_type<MyType, IdxParamPack>... 638 // expands to 639 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times] 640 template <typename T, std::size_t> 641 struct repeat_type { 642 using type = T; 643 }; 644 645 } // namespace detail 646 647 template <device::cupm::DeviceType T> 648 template <std::size_t... Idx> 649 inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept 650 { 651 PetscFunctionBegin; 652 // clang-format off 653 PetscCall( 654 PetscCUPMLaunchKernel1D( 655 size, 0, stream, 656 kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 657 size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()... 
658 ) 659 ); 660 // clang-format on 661 PetscFunctionReturn(PETSC_SUCCESS); 662 } 663 664 template <device::cupm::DeviceType T> 665 template <int N> 666 inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept 667 { 668 PetscFunctionBegin; 669 PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{})); 670 yidx += N; 671 PetscFunctionReturn(PETSC_SUCCESS); 672 } 673 674 // v->ops->maxpy 675 template <device::cupm::DeviceType T> 676 inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept 677 { 678 const auto n = xin->map->n; 679 PetscDeviceContext dctx; 680 cupmStream_t stream; 681 682 PetscFunctionBegin; 683 PetscCall(GetHandles_(&dctx, &stream)); 684 { 685 const auto xptr = DeviceArrayReadWrite(dctx, xin); 686 PetscScalar *d_alpha = nullptr; 687 PetscInt yidx = 0; 688 689 // placement of early-return is deliberate, we would like to capture the 690 // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail 691 if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS); 692 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha)); 693 PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream)); 694 PetscCall(PetscLogGpuTimeBegin()); 695 do { 696 switch (nv - yidx) { 697 case 7: 698 PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 699 break; 700 case 6: 701 PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 702 break; 703 case 5: 704 PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 705 break; 706 case 4: 707 PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 708 break; 709 case 3: 710 
PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 711 break; 712 case 2: 713 PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 714 break; 715 case 1: 716 PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 717 break; 718 default: // 8 or more 719 PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx)); 720 break; 721 } 722 } while (yidx < nv); 723 PetscCall(PetscLogGpuTimeEnd()); 724 PetscCall(PetscDeviceFree(dctx, d_alpha)); 725 } 726 PetscCall(PetscLogGpuFlops(nv * 2 * n)); 727 PetscCall(PetscDeviceContextSynchronize(dctx)); 728 PetscFunctionReturn(PETSC_SUCCESS); 729 } 730 731 template <device::cupm::DeviceType T> 732 inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept 733 { 734 PetscFunctionBegin; 735 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 736 PetscDeviceContext dctx; 737 cupmBlasHandle_t cupmBlasHandle; 738 739 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 740 // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the 741 // second 742 PetscCall(PetscLogGpuTimeBegin()); 743 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z))); 744 PetscCall(PetscLogGpuTimeEnd()); 745 PetscCall(PetscLogGpuFlops(2 * n - 1)); 746 } else { 747 *z = 0.0; 748 } 749 PetscFunctionReturn(PETSC_SUCCESS); 750 } 751 752 #define MDOT_WORKGROUP_NUM 128 753 #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM 754 755 namespace kernels 756 { 757 758 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept 759 { 760 const auto group_entries = (size - 1) / gridDim.x + 1; 761 // for very small vectors, a group should still do some work 762 return group_entries ? group_entries : 1; 763 } 764 765 template <typename... 
ConstPetscScalarPointer> 766 PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y) 767 { 768 constexpr int N = sizeof...(ConstPetscScalarPointer); 769 const PetscScalar *ylocal[] = {y...}; 770 PetscScalar sumlocal[N]; 771 772 PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE]; 773 774 // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate 775 // types, so each of these go on separate lines... 776 const auto tx = threadIdx.x; 777 const auto bx = blockIdx.x; 778 const auto bdx = blockDim.x; 779 const auto gdx = gridDim.x; 780 const auto worksize = EntriesPerGroup(size); 781 const auto begin = tx + bx * worksize; 782 const auto end = min((bx + 1) * worksize, size); 783 784 #pragma unroll 785 for (auto i = 0; i < N; ++i) sumlocal[i] = 0; 786 787 for (auto i = begin; i < end; i += bdx) { 788 const auto xi = x[i]; // load only once from global memory! 789 790 #pragma unroll 791 for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi; 792 } 793 794 #pragma unroll 795 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i]; 796 797 // parallel reduction 798 for (auto stride = bdx / 2; stride > 0; stride /= 2) { 799 __syncthreads(); 800 if (tx < stride) { 801 #pragma unroll 802 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE]; 803 } 804 } 805 // bottom N threads per block write to global memory 806 // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread 807 // writes to the same sections in the above loop that it is about to read from below, but 808 // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard. 
809 __syncthreads(); 810 if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE]; 811 return; 812 } 813 814 namespace 815 { 816 817 PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results) 818 { 819 int local_i = 0; 820 PetscScalar local_results[8]; 821 822 // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer 823 // 824 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 825 // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ... 826 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 827 // | ______________________________________________________/ 828 // | / <- MDOT_WORKGROUP_NUM -> 829 // |/ 830 // + 831 // v 832 // *-*-* 833 // | | | ... 834 // *-*-* 835 // 836 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 837 PetscScalar z_sum = 0; 838 839 for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j]; 840 local_results[local_i++] = z_sum; 841 }); 842 // if we needed more than 1 workgroup to handle the vector we should sync since other threads 843 // may currently be reading from results 844 if (size >= MDOT_WORKGROUP_SIZE) __syncthreads(); 845 // Local buffer is now written to global memory 846 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 847 const auto j = --local_i; 848 849 if (j >= 0) results[i] = local_results[j]; 850 }); 851 return; 852 } 853 854 } // namespace 855 856 } // namespace kernels 857 858 template <device::cupm::DeviceType T> 859 template <std::size_t... 
Idx> 860 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept 861 { 862 PetscFunctionBegin; 863 // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches 864 // 128 blocks of 128 threads every time which may be wasteful 865 // clang-format off 866 PetscCallCUPM( 867 cupmLaunchKernel( 868 kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>, 869 MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream, 870 xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()... 871 ) 872 ); 873 // clang-format on 874 PetscFunctionReturn(PETSC_SUCCESS); 875 } 876 877 template <device::cupm::DeviceType T> 878 template <int N> 879 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept 880 { 881 PetscFunctionBegin; 882 PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{})); 883 yidx += N; 884 PetscFunctionReturn(PETSC_SUCCESS); 885 } 886 887 template <device::cupm::DeviceType T> 888 inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept 889 { 890 // the largest possible size of a batch 891 constexpr PetscInt batchsize = 8; 892 // how many sub streams to create, if nv <= batchsize we can do this without looping, so we 893 // do not create substreams. Note we don't create more than 8 streams, in practice we could 894 // not get more parallelism with higher numbers. 895 const auto num_sub_streams = nv > batchsize ? 
std::min((nv + batchsize) / batchsize, batchsize) : 0; 896 const auto n = xin->map->n; 897 // number of vectors that we handle via the batches. note any singletons are handled by 898 // cublas, hence the nv-1. 899 const auto nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv; 900 const auto nwork = nvbatch * MDOT_WORKGROUP_NUM; 901 PetscScalar *d_results; 902 cupmStream_t stream; 903 904 PetscFunctionBegin; 905 PetscCall(GetHandlesFrom_(dctx, &stream)); 906 // allocate scratchpad memory for the results of individual work groups 907 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results)); 908 { 909 const auto xptr = DeviceArrayRead(dctx, xin); 910 PetscInt yidx = 0; 911 auto subidx = 0; 912 auto cur_stream = stream; 913 auto cur_ctx = dctx; 914 PetscDeviceContext *sub = nullptr; 915 PetscStreamType stype; 916 917 // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of 918 // sub. Ideally the parent context should also join in on the fork, but it is extremely 919 // fiddly to do so presently 920 PetscCall(PetscDeviceContextGetStreamType(dctx, &stype)); 921 if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING; 922 // If we have a globally blocking stream create nonblocking streams instead (as we can 923 // locally exploit the parallelism). Otherwise use the prescribed stream type. 924 PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub)); 925 PetscCall(PetscLogGpuTimeBegin()); 926 do { 927 if (num_sub_streams) { 928 cur_ctx = sub[subidx++ % num_sub_streams]; 929 PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream)); 930 } 931 // REVIEW ME: Should probably try and load-balance these. 
Consider the case where nv = 9; 932 // it is very likely better to do 4+5 rather than 8+1 933 switch (nv - yidx) { 934 case 7: 935 PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 936 break; 937 case 6: 938 PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 939 break; 940 case 5: 941 PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 942 break; 943 case 4: 944 PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 945 break; 946 case 3: 947 PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 948 break; 949 case 2: 950 PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 951 break; 952 case 1: { 953 cupmBlasHandle_t cupmBlasHandle; 954 955 PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle)); 956 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx))); 957 ++yidx; 958 } break; 959 default: // 8 or more 960 PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx)); 961 break; 962 } 963 } while (yidx < nv); 964 PetscCall(PetscLogGpuTimeEnd()); 965 PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub)); 966 } 967 968 PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results)); 969 // copy result of device reduction to host 970 PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream)); 971 // do these now while final reduction is in flight 972 PetscCall(PetscLogFlops(nwork)); 973 PetscCall(PetscDeviceFree(dctx, d_results)); 974 PetscFunctionReturn(PETSC_SUCCESS); 975 } 976 977 #undef MDOT_WORKGROUP_NUM 978 #undef MDOT_WORKGROUP_SIZE 979 980 
template <device::cupm::DeviceType T> 981 inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept 982 { 983 // probably not worth it to run more than 8 of these at a time? 984 const auto n_sub = PetscMin(nv, 8); 985 const auto n = static_cast<cupmBlasInt_t>(xin->map->n); 986 const auto xptr = DeviceArrayRead(dctx, xin); 987 PetscScalar *d_z; 988 PetscDeviceContext *subctx; 989 cupmStream_t stream; 990 991 PetscFunctionBegin; 992 PetscCall(GetHandlesFrom_(dctx, &stream)); 993 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z)); 994 PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx)); 995 PetscCall(PetscLogGpuTimeBegin()); 996 for (PetscInt i = 0; i < nv; ++i) { 997 const auto sub = subctx[i % n_sub]; 998 cupmBlasHandle_t handle; 999 cupmBlasPointerMode_t old_mode; 1000 1001 PetscCall(GetHandlesFrom_(sub, &handle)); 1002 PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode)); 1003 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE)); 1004 PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i))); 1005 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode)); 1006 } 1007 PetscCall(PetscLogGpuTimeEnd()); 1008 PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx)); 1009 PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream)); 1010 PetscCall(PetscDeviceFree(dctx, d_z)); 1011 // REVIEW ME: flops????? 
1012 PetscFunctionReturn(PETSC_SUCCESS); 1013 } 1014 1015 // v->ops->mdot 1016 template <device::cupm::DeviceType T> 1017 inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept 1018 { 1019 PetscFunctionBegin; 1020 if (PetscUnlikely(nv == 1)) { 1021 // dot handles nv = 0 correctly 1022 PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z)); 1023 } else if (const auto n = xin->map->n) { 1024 PetscDeviceContext dctx; 1025 1026 PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv); 1027 PetscCall(GetHandles_(&dctx)); 1028 PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx)); 1029 // REVIEW ME: double count of flops?? 1030 PetscCall(PetscLogGpuFlops(nv * (2 * n - 1))); 1031 PetscCall(PetscDeviceContextSynchronize(dctx)); 1032 } else { 1033 PetscCall(PetscArrayzero(z, nv)); 1034 } 1035 PetscFunctionReturn(PETSC_SUCCESS); 1036 } 1037 1038 // v->ops->set 1039 template <device::cupm::DeviceType T> 1040 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept 1041 { 1042 const auto n = xin->map->n; 1043 PetscDeviceContext dctx; 1044 cupmStream_t stream; 1045 1046 PetscFunctionBegin; 1047 PetscCall(GetHandles_(&dctx, &stream)); 1048 { 1049 const auto xptr = DeviceArrayWrite(dctx, xin); 1050 1051 if (alpha == PetscScalar(0.0)) { 1052 PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream)); 1053 } else { 1054 const auto dptr = thrust::device_pointer_cast(xptr.data()); 1055 1056 PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha)); 1057 } 1058 if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing 1059 } 1060 PetscFunctionReturn(PETSC_SUCCESS); 1061 } 1062 1063 // v->ops->scale 1064 template <device::cupm::DeviceType T> 1065 inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept 1066 { 1067 
PetscFunctionBegin; 1068 if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS); 1069 if (PetscUnlikely(alpha == PetscScalar(0.0))) { 1070 PetscCall(Set(xin, alpha)); 1071 } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1072 PetscDeviceContext dctx; 1073 cupmBlasHandle_t cupmBlasHandle; 1074 1075 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1076 PetscCall(PetscLogGpuTimeBegin()); 1077 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1)); 1078 PetscCall(PetscLogGpuTimeEnd()); 1079 PetscCall(PetscLogGpuFlops(n)); 1080 PetscCall(PetscDeviceContextSynchronize(dctx)); 1081 } else { 1082 PetscCall(MaybeIncrementEmptyLocalVec(xin)); 1083 } 1084 PetscFunctionReturn(PETSC_SUCCESS); 1085 } 1086 1087 // v->ops->tdot 1088 template <device::cupm::DeviceType T> 1089 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept 1090 { 1091 PetscFunctionBegin; 1092 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1093 PetscDeviceContext dctx; 1094 cupmBlasHandle_t cupmBlasHandle; 1095 1096 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1097 PetscCall(PetscLogGpuTimeBegin()); 1098 PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z))); 1099 PetscCall(PetscLogGpuTimeEnd()); 1100 PetscCall(PetscLogGpuFlops(2 * n - 1)); 1101 } else { 1102 *z = 0.0; 1103 } 1104 PetscFunctionReturn(PETSC_SUCCESS); 1105 } 1106 1107 // v->ops->copy 1108 template <device::cupm::DeviceType T> 1109 inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept 1110 { 1111 PetscFunctionBegin; 1112 if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS); 1113 if (const auto n = xin->map->n) { 1114 const auto xmask = xin->offloadmask; 1115 // silence buggy gcc warning: mode may be used uninitialized in this function 1116 auto mode = cupmMemcpyDeviceToDevice; 1117 
PetscDeviceContext dctx; 1118 cupmStream_t stream; 1119 1120 // translate from PetscOffloadMask to cupmMemcpyKind 1121 switch (const auto ymask = yout->offloadmask) { 1122 case PETSC_OFFLOAD_UNALLOCATED: { 1123 PetscBool yiscupm; 1124 1125 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), "")); 1126 if (yiscupm) { 1127 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost; 1128 break; 1129 } 1130 } // fall-through if unallocated and not cupm 1131 #if PETSC_CPP_VERSION >= 17 1132 [[fallthrough]]; 1133 #endif 1134 case PETSC_OFFLOAD_CPU: 1135 mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost; 1136 break; 1137 case PETSC_OFFLOAD_BOTH: 1138 case PETSC_OFFLOAD_GPU: 1139 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice; 1140 break; 1141 default: 1142 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask)); 1143 } 1144 1145 PetscCall(GetHandles_(&dctx, &stream)); 1146 switch (mode) { 1147 case cupmMemcpyDeviceToDevice: // the best case 1148 case cupmMemcpyHostToDevice: { // not terrible 1149 const auto yptr = DeviceArrayWrite(dctx, yout); 1150 const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data(); 1151 1152 PetscCall(PetscLogGpuTimeBegin()); 1153 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream)); 1154 PetscCall(PetscLogGpuTimeEnd()); 1155 } break; 1156 case cupmMemcpyDeviceToHost: // not great 1157 case cupmMemcpyHostToHost: { // worst case 1158 const auto xptr = mode == cupmMemcpyDeviceToHost ? 
DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data(); 1159 PetscScalar *yptr; 1160 1161 PetscCall(VecGetArrayWrite(yout, &yptr)); 1162 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin()); 1163 PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true)); 1164 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd()); 1165 PetscCall(VecRestoreArrayWrite(yout, &yptr)); 1166 } break; 1167 default: 1168 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode)); 1169 } 1170 PetscCall(PetscDeviceContextSynchronize(dctx)); 1171 } else { 1172 PetscCall(MaybeIncrementEmptyLocalVec(yout)); 1173 } 1174 PetscFunctionReturn(PETSC_SUCCESS); 1175 } 1176 1177 // v->ops->swap 1178 template <device::cupm::DeviceType T> 1179 inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept 1180 { 1181 PetscFunctionBegin; 1182 if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS); 1183 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1184 PetscDeviceContext dctx; 1185 cupmBlasHandle_t cupmBlasHandle; 1186 1187 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1188 PetscCall(PetscLogGpuTimeBegin()); 1189 PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1)); 1190 PetscCall(PetscLogGpuTimeEnd()); 1191 PetscCall(PetscDeviceContextSynchronize(dctx)); 1192 } else { 1193 PetscCall(MaybeIncrementEmptyLocalVec(xin)); 1194 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 1195 } 1196 PetscFunctionReturn(PETSC_SUCCESS); 1197 } 1198 1199 // v->ops->axpby 1200 template <device::cupm::DeviceType T> 1201 inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept 1202 { 1203 PetscFunctionBegin; 1204 if (alpha == PetscScalar(0.0)) { 1205 PetscCall(Scale(yin, beta)); 1206 } else if (beta == PetscScalar(1.0)) { 1207 PetscCall(AXPY(yin, alpha, xin)); 1208 } else if (alpha == 
PetscScalar(1.0)) { 1209 PetscCall(AYPX(yin, beta, xin)); 1210 } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) { 1211 const auto betaIsZero = beta == PetscScalar(0.0); 1212 const auto aptr = cupmScalarPtrCast(&alpha); 1213 PetscDeviceContext dctx; 1214 cupmBlasHandle_t cupmBlasHandle; 1215 1216 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1217 { 1218 const auto xptr = DeviceArrayRead(dctx, xin); 1219 1220 if (betaIsZero /* beta = 0 */) { 1221 // here we can get away with purely write-only as we memcpy into it first 1222 const auto yptr = DeviceArrayWrite(dctx, yin); 1223 cupmStream_t stream; 1224 1225 PetscCall(GetHandlesFrom_(dctx, &stream)); 1226 PetscCall(PetscLogGpuTimeBegin()); 1227 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream)); 1228 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1)); 1229 } else { 1230 const auto yptr = DeviceArrayReadWrite(dctx, yin); 1231 1232 PetscCall(PetscLogGpuTimeBegin()); 1233 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1)); 1234 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1)); 1235 } 1236 } 1237 PetscCall(PetscLogGpuTimeEnd()); 1238 PetscCall(PetscLogGpuFlops((betaIsZero ? 
1 : 3) * n)); 1239 PetscCall(PetscDeviceContextSynchronize(dctx)); 1240 } else { 1241 PetscCall(MaybeIncrementEmptyLocalVec(yin)); 1242 } 1243 PetscFunctionReturn(PETSC_SUCCESS); 1244 } 1245 1246 // v->ops->axpbypcz 1247 template <device::cupm::DeviceType T> 1248 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept 1249 { 1250 PetscFunctionBegin; 1251 if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma)); 1252 PetscCall(AXPY(zin, alpha, xin)); 1253 PetscCall(AXPY(zin, beta, yin)); 1254 PetscFunctionReturn(PETSC_SUCCESS); 1255 } 1256 1257 // v->ops->norm 1258 template <device::cupm::DeviceType T> 1259 inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept 1260 { 1261 PetscDeviceContext dctx; 1262 cupmBlasHandle_t cupmBlasHandle; 1263 1264 PetscFunctionBegin; 1265 PetscCall(GetHandles_(&dctx, &cupmBlasHandle)); 1266 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) { 1267 const auto xptr = DeviceArrayRead(dctx, xin); 1268 PetscInt flopCount = 0; 1269 1270 PetscCall(PetscLogGpuTimeBegin()); 1271 switch (type) { 1272 case NORM_1_AND_2: 1273 case NORM_1: 1274 PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z))); 1275 flopCount = std::max(n - 1, 0); 1276 if (type == NORM_1) break; 1277 ++z; // fall-through 1278 #if PETSC_CPP_VERSION >= 17 1279 [[fallthrough]]; 1280 #endif 1281 case NORM_2: 1282 case NORM_FROBENIUS: 1283 PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z))); 1284 flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2 1285 break; 1286 case NORM_INFINITY: { 1287 cupmBlasInt_t max_loc = 0; 1288 PetscScalar xv = 0.; 1289 cupmStream_t stream; 1290 1291 PetscCall(GetHandlesFrom_(dctx, &stream)); 1292 PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc)); 1293 
PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream)); 1294 *z = PetscAbsScalar(xv); 1295 // REVIEW ME: flopCount = ??? 1296 } break; 1297 } 1298 PetscCall(PetscLogGpuTimeEnd()); 1299 PetscCall(PetscLogGpuFlops(flopCount)); 1300 } else { 1301 z[0] = 0.0; 1302 z[type == NORM_1_AND_2] = 0.0; 1303 } 1304 PetscFunctionReturn(PETSC_SUCCESS); 1305 } 1306 1307 namespace detail 1308 { 1309 1310 struct dotnorm2_mult { 1311 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept 1312 { 1313 const auto conjt = PetscConj(t); 1314 1315 return {s * conjt, t * conjt}; 1316 } 1317 }; 1318 1319 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I 1320 // would do it myself but now I am worried that they do so on purpose... 1321 struct dotnorm2_tuple_plus { 1322 using value_type = thrust::tuple<PetscScalar, PetscScalar>; 1323 1324 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; } 1325 }; 1326 1327 } // namespace detail 1328 1329 // v->ops->dotnorm2 1330 template <device::cupm::DeviceType T> 1331 inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept 1332 { 1333 PetscDeviceContext dctx; 1334 cupmStream_t stream; 1335 1336 PetscFunctionBegin; 1337 PetscCall(GetHandles_(&dctx, &stream)); 1338 { 1339 PetscScalar dpt = 0.0, nmt = 0.0; 1340 const auto sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data()); 1341 1342 // clang-format off 1343 PetscCallThrust( 1344 thrust::tie(*dp, *nm) = THRUST_CALL( 1345 thrust::inner_product, 1346 stream, 1347 sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()), 1348 thrust::make_tuple(dpt, nmt), 1349 detail::dotnorm2_tuple_plus{}, 
detail::dotnorm2_mult{} 1350 ); 1351 ); 1352 // clang-format on 1353 } 1354 PetscFunctionReturn(PETSC_SUCCESS); 1355 } 1356 1357 namespace detail 1358 { 1359 1360 struct conjugate { 1361 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); } 1362 }; 1363 1364 } // namespace detail 1365 1366 // v->ops->conjugate 1367 template <device::cupm::DeviceType T> 1368 inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept 1369 { 1370 PetscFunctionBegin; 1371 if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin)); 1372 PetscFunctionReturn(PETSC_SUCCESS); 1373 } 1374 1375 namespace detail 1376 { 1377 1378 struct real_part { 1379 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; } 1380 1381 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); } 1382 }; 1383 1384 // deriving from Operator allows us to "store" an instance of the operator in the class but 1385 // also take advantage of empty base class optimization if the operator is stateless 1386 template <typename Operator> 1387 class tuple_compare : Operator { 1388 public: 1389 using tuple_type = thrust::tuple<PetscReal, PetscInt>; 1390 using operator_type = Operator; 1391 1392 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept 1393 { 1394 if (op_()(y.get<0>(), x.get<0>())) { 1395 // if y is strictly greater/less than x, return y 1396 return y; 1397 } else if (y.get<0>() == x.get<0>()) { 1398 // if equal, prefer lower index 1399 return y.get<1>() < x.get<1>() ? 
y : x; 1400 } 1401 // otherwise return x 1402 return x; 1403 } 1404 1405 private: 1406 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; } 1407 }; 1408 1409 } // namespace detail 1410 1411 template <device::cupm::DeviceType T> 1412 template <typename TupleFuncT, typename UnaryFuncT> 1413 inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept 1414 { 1415 PetscFunctionBegin; 1416 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 1417 if (p) *p = -1; 1418 if (const auto n = v->map->n) { 1419 PetscDeviceContext dctx; 1420 cupmStream_t stream; 1421 1422 PetscCall(GetHandles_(&dctx, &stream)); 1423 // needed to: 1424 // 1. switch between transform_reduce and reduce 1425 // 2. strip the real_part functor from the arguments 1426 #if PetscDefined(USE_COMPLEX) 1427 #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__) 1428 #else 1429 #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__) 1430 #endif 1431 { 1432 const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1433 1434 if (p) { 1435 // clang-format off 1436 const auto zip = thrust::make_zip_iterator( 1437 thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0})) 1438 ); 1439 // clang-format on 1440 // need to use preprocessor conditionals since otherwise thrust complains about not being 1441 // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex 1442 // builds... 
1443 // clang-format off 1444 PetscCallThrust( 1445 thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE( 1446 stream, zip, zip + n, detail::real_part{}, 1447 thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr) 1448 ); 1449 ); 1450 // clang-format on 1451 } else { 1452 // clang-format off 1453 PetscCallThrust( 1454 *m = THRUST_MINMAX_REDUCE( 1455 stream, vptr, vptr + n, detail::real_part{}, 1456 *m, std::forward<UnaryFuncT>(unary_ftr) 1457 ); 1458 ); 1459 // clang-format on 1460 } 1461 } 1462 #undef THRUST_MINMAX_REDUCE 1463 } 1464 // REVIEW ME: flops? 1465 PetscFunctionReturn(PETSC_SUCCESS); 1466 } 1467 1468 // v->ops->max 1469 template <device::cupm::DeviceType T> 1470 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept 1471 { 1472 using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>; 1473 using unary_functor = thrust::maximum<PetscReal>; 1474 1475 PetscFunctionBegin; 1476 *m = PETSC_MIN_REAL; 1477 // use {} constructor syntax otherwise most vexing parse 1478 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1479 PetscFunctionReturn(PETSC_SUCCESS); 1480 } 1481 1482 // v->ops->min 1483 template <device::cupm::DeviceType T> 1484 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept 1485 { 1486 using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>; 1487 using unary_functor = thrust::minimum<PetscReal>; 1488 1489 PetscFunctionBegin; 1490 *m = PETSC_MAX_REAL; 1491 // use {} constructor syntax otherwise most vexing parse 1492 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1493 PetscFunctionReturn(PETSC_SUCCESS); 1494 } 1495 1496 // v->ops->sum 1497 template <device::cupm::DeviceType T> 1498 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept 1499 { 1500 PetscFunctionBegin; 1501 if (const auto n = v->map->n) { 1502 PetscDeviceContext dctx; 1503 cupmStream_t stream; 1504 1505 PetscCall(GetHandles_(&dctx, &stream)); 1506 
const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1507 // REVIEW ME: why not cupmBlasXasum()? 1508 PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0});); 1509 // REVIEW ME: must be at least n additions 1510 PetscCall(PetscLogGpuFlops(n)); 1511 } else { 1512 *sum = 0.0; 1513 } 1514 PetscFunctionReturn(PETSC_SUCCESS); 1515 } 1516 1517 template <device::cupm::DeviceType T> 1518 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept 1519 { 1520 PetscFunctionBegin; 1521 PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v)); 1522 PetscFunctionReturn(PETSC_SUCCESS); 1523 } 1524 1525 template <device::cupm::DeviceType T> 1526 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept 1527 { 1528 PetscFunctionBegin; 1529 if (const auto n = v->map->n) { 1530 PetscBool iscurand; 1531 PetscDeviceContext dctx; 1532 1533 PetscCall(GetHandles_(&dctx)); 1534 PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand)); 1535 if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v))); 1536 else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v))); 1537 } else { 1538 PetscCall(MaybeIncrementEmptyLocalVec(v)); 1539 } 1540 // REVIEW ME: flops???? 1541 // REVIEW ME: Timing??? 
1542 PetscFunctionReturn(PETSC_SUCCESS); 1543 } 1544 1545 // v->ops->setpreallocation 1546 template <device::cupm::DeviceType T> 1547 inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept 1548 { 1549 PetscDeviceContext dctx; 1550 1551 PetscFunctionBegin; 1552 PetscCall(GetHandles_(&dctx)); 1553 PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i)); 1554 PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx)); 1555 PetscFunctionReturn(PETSC_SUCCESS); 1556 } 1557 1558 namespace kernels 1559 { 1560 1561 template <typename F> 1562 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 1563 { 1564 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 1565 const auto end = jmap[i + 1]; 1566 const auto idx = xvindex(i); 1567 PetscScalar sum = 0.0; 1568 1569 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 1570 1571 if (imode == INSERT_VALUES) { 1572 xv[idx] = sum; 1573 } else { 1574 xv[idx] += sum; 1575 } 1576 }); 1577 return; 1578 } 1579 1580 namespace 1581 { 1582 1583 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 1584 { 1585 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 1586 return; 1587 } 1588 1589 } // namespace 1590 1591 #if PetscDefined(USING_HCC) 1592 namespace do_not_use 1593 { 1594 1595 // Needed to silence clang warning: 1596 // 1597 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 1598 // 1599 // The warning is silly, since the function *is* used, however the host compiler does not 1600 // appear see this. 
Likely because the function using it is in a template. 1601 // 1602 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 1603 inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted() 1604 { 1605 (void)sum_kernel; 1606 } 1607 1608 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 1609 { 1610 (void)add_coo_values; 1611 } 1612 1613 } // namespace do_not_use 1614 #endif 1615 1616 } // namespace kernels 1617 1618 // v->ops->setvaluescoo 1619 template <device::cupm::DeviceType T> 1620 inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept 1621 { 1622 auto vv = const_cast<PetscScalar *>(v); 1623 PetscMemType memtype; 1624 PetscDeviceContext dctx; 1625 cupmStream_t stream; 1626 1627 PetscFunctionBegin; 1628 PetscCall(GetHandles_(&dctx, &stream)); 1629 PetscCall(PetscGetMemType(v, &memtype)); 1630 if (PetscMemTypeHost(memtype)) { 1631 const auto size = VecIMPLCast(x)->coo_n; 1632 1633 // If user gave v[] in host, we might need to copy it to device if any 1634 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv)); 1635 PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream)); 1636 } 1637 1638 if (const auto n = x->map->n) { 1639 const auto vcu = VecCUPMCast(x); 1640 1641 PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? 
DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data())); 1642 } else { 1643 PetscCall(MaybeIncrementEmptyLocalVec(x)); 1644 } 1645 1646 if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv)); 1647 PetscCall(PetscDeviceContextSynchronize(dctx)); 1648 PetscFunctionReturn(PETSC_SUCCESS); 1649 } 1650 1651 } // namespace impl 1652 1653 // ========================================================================================== 1654 // VecSeq_CUPM - Implementations 1655 // ========================================================================================== 1656 1657 namespace 1658 { 1659 1660 template <device::cupm::DeviceType T> 1661 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 1662 { 1663 PetscFunctionBegin; 1664 PetscValidPointer(v, 4); 1665 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 1666 PetscFunctionReturn(PETSC_SUCCESS); 1667 } 1668 1669 template <device::cupm::DeviceType T> 1670 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 1671 { 1672 PetscFunctionBegin; 1673 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4); 1674 PetscValidPointer(v, 6); 1675 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 1676 PetscFunctionReturn(PETSC_SUCCESS); 1677 } 1678 1679 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1680 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1681 { 1682 PetscFunctionBegin; 1683 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1684 PetscValidPointer(a, 2); 1685 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1686 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1687 PetscFunctionReturn(PETSC_SUCCESS); 1688 } 1689 1690 template 
<PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1691 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1692 { 1693 PetscFunctionBegin; 1694 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1695 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1696 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1697 PetscFunctionReturn(PETSC_SUCCESS); 1698 } 1699 1700 template <device::cupm::DeviceType T> 1701 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1702 { 1703 PetscFunctionBegin; 1704 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 1705 PetscFunctionReturn(PETSC_SUCCESS); 1706 } 1707 1708 template <device::cupm::DeviceType T> 1709 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1710 { 1711 PetscFunctionBegin; 1712 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 1713 PetscFunctionReturn(PETSC_SUCCESS); 1714 } 1715 1716 template <device::cupm::DeviceType T> 1717 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1718 { 1719 PetscFunctionBegin; 1720 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 1721 PetscFunctionReturn(PETSC_SUCCESS); 1722 } 1723 1724 template <device::cupm::DeviceType T> 1725 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1726 { 1727 PetscFunctionBegin; 1728 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 1729 PetscFunctionReturn(PETSC_SUCCESS); 1730 } 1731 1732 template <device::cupm::DeviceType T> 1733 inline PetscErrorCode 
VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1734 { 1735 PetscFunctionBegin; 1736 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 1737 PetscFunctionReturn(PETSC_SUCCESS); 1738 } 1739 1740 template <device::cupm::DeviceType T> 1741 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1742 { 1743 PetscFunctionBegin; 1744 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 1745 PetscFunctionReturn(PETSC_SUCCESS); 1746 } 1747 1748 template <device::cupm::DeviceType T> 1749 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 1750 { 1751 PetscFunctionBegin; 1752 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 1753 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1754 PetscFunctionReturn(PETSC_SUCCESS); 1755 } 1756 1757 template <device::cupm::DeviceType T> 1758 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 1759 { 1760 PetscFunctionBegin; 1761 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 1762 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 1763 PetscFunctionReturn(PETSC_SUCCESS); 1764 } 1765 1766 template <device::cupm::DeviceType T> 1767 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 1768 { 1769 PetscFunctionBegin; 1770 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 1771 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 1772 PetscFunctionReturn(PETSC_SUCCESS); 1773 } 1774 1775 } // anonymous namespace 1776 1777 } // namespace cupm 1778 1779 } // namespace vec 1780 1781 } // namespace Petsc 1782 1783 #endif // __cplusplus 1784 1785 #endif // PETSCVECSEQCUPM_HPP 1786