#ifndef PETSCVECSEQCUPM_HPP
#define PETSCVECSEQCUPM_HPP

#include <petsc/private/veccupmimpl.h>

#if defined(__cplusplus)
#include <petsc/private/randomimpl.h> // for _p_PetscRandom

#include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence

#include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
#include "../src/sys/objects/device/impls/cupm/kernels.hpp"

#if PetscDefined(USE_COMPLEX)
#include <thrust/transform_reduce.h>
#endif
#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/inner_product.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

// ==========================================================================================
// VecSeq_CUPM
// ==========================================================================================

// Sequential (single-rank) vector implementation for the CUPM (CUDA/HIP) backends. CRTP:
// derives from Vec_CUPMBase<T, VecSeq_CUPM<T>>, which supplies the shared host/device array
// management (DeviceArrayRead/Write/ReadWrite, GetHandles_, etc. -- brought into scope via
// PETSC_VEC_CUPM_BASE_CLASS_HEADER). T selects the device backend at compile time.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // hooks consumed by the base class: cast to the host-side impl struct and report the
  // implementation type names
  PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  // thin forwarders to the host (Vec_Seq) implementations of the corresponding operations
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state on ranks whose local size is 0 (see rationale comment above the
  // definition below)
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers (installed into v->ops by BindToCPU)
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

// ==========================================================================================
// VecSeq_CUPM - Private API
// ==========================================================================================

// return the host-side Vec_Seq struct stored in v->data
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

// type name of the device implementation (VECSEQCUDA/VECSEQHIP depending on T)
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

// type name of the matching host implementation
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}

template <device::cupm::DeviceType T>
inline PetscErrorCode 
VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  // delegate teardown of the host-side data to the sequential implementation
  return VecDestroy_Seq(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

// create the host-side (Vec_Seq) portion of the vector; a VecSeq must live on a
// single-process communicator
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// for functions with an early return based on vec size we still need to artificially bump the
// object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//      rank 0: [0],
//      rank 1: [<empty>]
//    }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  // only the locally-empty piece of a globally non-empty vector needs the bump
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// shared body of the create routines: build the host half, then initialize the CUPM half
// with the (possibly user-provided) host/device arrays
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// apply a thrust binary functor elementwise: zout[i] = binary(xin[i], yin[i])
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // locally empty: still bump the state to keep collective callers in sync (see
    // MaybeIncrementEmptyLocalVec)
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// apply a thrust unary functor elementwise; writes in place when yin is null or aliases
// xinout, otherwise yin receives the result
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) 
noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t stream;
    // helper so the in-place and out-of-place branches share one thrust call
    const auto apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // bump the state of whichever vector would have been written (see
    // MaybeIncrementEmptyLocalVec for why)
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Constructors
// ==========================================================================================

// VecCreateSeqCUPM()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  // sequential: local size == global size == n
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCreateSeqCUPMWithArrays()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->duplicate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Utility
// ==========================================================================================

// v->ops->bindtocpu -- after the base handles the array migration, (re)install the op
// table: host (Vec*_Seq) implementations when usehost, device implementations otherwise
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // no device mtdot implementation: always use the host one
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  VecSetOp_CUPM(errorwnorm, nullptr, ErrorWnorm);
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Mutators
// ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread -- make w an alias for v's local
// data; any storage w previously owned is released first
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // drop w's own host allocation (with the pinned-memory allocator if it was pinned)...
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array = nullptr;
      wseq->unplacedarray = nullptr;
    }
    // ...and its own device allocation
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: w shares v's data pointers outright (undone in RestoreLocalVector)
    PetscCall(PetscFree(w->data));
    w->data = v->data;
    w->offloadmask = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // slow path: borrow v's host array through the public accessor
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread -- inverse of
// GetLocalVector: hand the borrowed data back to v
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data = util::exchange(w->data, nullptr);
    v->spptr = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      // free the device array allocated by DeviceAllocateCheck_() in GetLocalVector
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Compute Methods
// ==========================================================================================

// v->ops->aypx
template <device::cupm::DeviceType T>
// y = x + alpha*y, computed via cupmBLAS (memcpy when alpha == 0, axpy when alpha == 1,
// scal+axpy otherwise)
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto sync = n != 0;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    // y = x: plain device-to-device copy
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto alphaIsOne = alpha == PetscScalar(1.0);
    const auto calpha = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // y += x
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        // y = alpha*y; y += x
        const auto one = cupmScalarCast(1.0);

        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpy -- y += alpha*x; falls back to the host path when x is not a CUPM vector
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisedivide -- w[i] = x[i]/y[i]
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, win is written!
    PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemult -- w[i] = x[i]*y[i]
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall(VecPointwiseMult_Seq(win, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, win is written!
    PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// functor for VecReciprocal: maps s -> 1/s, leaving zero entries unchanged
struct reciprocal {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
  {
    // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
    // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
    // everything in PetscScalar...
    return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
  }
};

} // namespace detail

// v->ops->reciprocal
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->waxpy -- w = alpha*x + y (copy of y when alpha == 0)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;
    cupmStream_t stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      // w = y, then w += alpha*x
      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

template 
<typename... Args>
// device kernel computing x += sum_j a[j]*y[j] for a compile-time pack of y vectors; the
// coefficients are staged through shared memory (one element per y vector)
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int N = sizeof...(Args);
  const auto tx = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
#else
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
//    typename repeat_type<MyType, IdxParamPack>...
// expands to
//    MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};

} // namespace detail

// launch MAXPY_kernel instantiated for sizeof...(Idx) y-vectors
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// process the next N vectors starting at yidx and advance yidx by N
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->maxpy -- x += sum_i alpha[i]*y[i], batched through kernels taking up to 8
// vectors per launch
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto xptr = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt yidx = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    // stage the coefficients on the device; MAXPY_kernel reads them from device memory
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // dispatch on the number of vectors still to process (8+ handled 8 at a time)
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->dot -- z = conj(x) . y via cupmBLAS dot
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// launch geometry for the MDot kernels: 128 blocks of 128 threads
#define MDOT_WORKGROUP_NUM 128
#define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

namespace kernels
{

// number of vector entries each block ("group") is responsible for
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

template <typename... 
ConstPetscScalarPointer> 777 PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y) 778 { 779 constexpr int N = sizeof...(ConstPetscScalarPointer); 780 const PetscScalar *ylocal[] = {y...}; 781 PetscScalar sumlocal[N]; 782 783 PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE]; 784 785 // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate 786 // types, so each of these go on separate lines... 787 const auto tx = threadIdx.x; 788 const auto bx = blockIdx.x; 789 const auto bdx = blockDim.x; 790 const auto gdx = gridDim.x; 791 const auto worksize = EntriesPerGroup(size); 792 const auto begin = tx + bx * worksize; 793 const auto end = min((bx + 1) * worksize, size); 794 795 #pragma unroll 796 for (auto i = 0; i < N; ++i) sumlocal[i] = 0; 797 798 for (auto i = begin; i < end; i += bdx) { 799 const auto xi = x[i]; // load only once from global memory! 800 801 #pragma unroll 802 for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi; 803 } 804 805 #pragma unroll 806 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i]; 807 808 // parallel reduction 809 for (auto stride = bdx / 2; stride > 0; stride /= 2) { 810 __syncthreads(); 811 if (tx < stride) { 812 #pragma unroll 813 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE]; 814 } 815 } 816 // bottom N threads per block write to global memory 817 // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread 818 // writes to the same sections in the above loop that it is about to read from below, but 819 // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard. 
820 __syncthreads(); 821 if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE]; 822 return; 823 } 824 825 namespace 826 { 827 828 PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results) 829 { 830 int local_i = 0; 831 PetscScalar local_results[8]; 832 833 // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer 834 // 835 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 836 // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ... 837 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 838 // | ______________________________________________________/ 839 // | / <- MDOT_WORKGROUP_NUM -> 840 // |/ 841 // + 842 // v 843 // *-*-* 844 // | | | ... 845 // *-*-* 846 // 847 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 848 PetscScalar z_sum = 0; 849 850 for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j]; 851 local_results[local_i++] = z_sum; 852 }); 853 // if we needed more than 1 workgroup to handle the vector we should sync since other threads 854 // may currently be reading from results 855 if (size >= MDOT_WORKGROUP_SIZE) __syncthreads(); 856 // Local buffer is now written to global memory 857 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) { 858 const auto j = --local_i; 859 860 if (j >= 0) results[i] = local_results[j]; 861 }); 862 return; 863 } 864 865 } // namespace 866 867 } // namespace kernels 868 869 template <device::cupm::DeviceType T> 870 template <std::size_t... 
Idx>
// launches MDot_kernel for exactly sizeof...(Idx) y vectors; the index sequence expands both
// the kernel's pointer parameter pack and the device-array reads
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// batch-of-N front end for the dispatcher above: offsets yin and results by the running index
// yidx and advances yidx by N on success
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// real-scalar MDot implementation: processes the y vectors in batches of up to 8 custom-kernel
// launches (round-robined over forked sub-streams for concurrency), then reduces the per-block
// partials with sum_kernel. A single trailing leftover vector is handed to cupmBLAS instead.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a lone final vector: cheaper to let cupmBLAS do it (writes straight to z, bypassing
        // the d_results scratchpad)
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // final reduction of the per-block partials, then async copy of the nvbatch results to host
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE
// complex-scalar MDot implementation: one cupmBLAS dot per y vector, fanned out over up to 8
// forked sub-contexts, with results accumulated in a device buffer (device pointer mode) and
// copied back in one async memcpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // temporarily switch the handle to device pointer mode so the dot result lands directly in
    // d_z on the device; restore the caller's mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot -- dispatches to the real or complex MDot_ implementation above; nv == 1
// forwards to Dot(), an empty local vector just zeroes z
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->set -- fills the vector with alpha; alpha == 0 uses a memset, anything else a
// thrust::fill
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto xptr = DeviceArrayWrite(dctx, xin);

    if (alpha == PetscScalar(0.0)) {
      PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
    } else {
      const auto dptr = thrust::device_pointer_cast(xptr.data());

      PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
    }
    if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->scale -- x *= alpha; alpha == 1 is a no-op, alpha == 0 delegates to Set()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(Set(xin, alpha));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->tdot -- unconjugated ("transpose") dot product z = y^T x via cupmBlasXdotu
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->copy -- copies x into y, picking the memcpy direction from the two vectors' offload
// masks so data is moved from wherever x currently lives to wherever y should live
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // destination is host memory, so go through the generic host array accessors
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->swap -- exchanges the contents of x and y via cupmBlasXswap
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpby -- y = alpha*x + beta*y, with special-cased delegation for the trivial alpha
// and beta values
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    PetscCall(AXPY(yin, alpha, xin));
  } else if (alpha ==
PetscScalar(1.0)) {
    PetscCall(AYPX(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    // general case: y = beta*y then y += alpha*x (or memcpy + scal when beta == 0)
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    // beta == 0: one scal (n flops); otherwise scal + axpy (3n flops)
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpbypcz -- z = alpha*x + beta*y + gamma*z, composed from Scale() and two AXPY()s
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma));
  PetscCall(AXPY(zin, alpha, xin));
  PetscCall(AXPY(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->norm -- 1, 2/Frobenius, infinity, or combined 1-and-2 norms via cupmBLAS; for
// NORM_1_AND_2 both results are written through z (z[0] and z[1])
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // amax returns a 1-based index (BLAS convention), hence the - 1 when reading the entry
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty local vector: zero one slot, or two for NORM_1_AND_2 (the index trick writes z[1]
    // only in that case)
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// shared state and helper for the weighted-error-norm functors below: holds the ignore
// threshold and turns an (error, tolerance) pair into a (norm contribution, counted?) pair
template <NormType wnormtype>
class ErrorWNormTransformBase {
public:
  using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

  constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

protected:
  struct NormTuple {
    PetscReal norm;
    PetscInt  loc;
  };

  // contribution of one entry: err/tol (or its square for non-infinity norms) and a count of
  // 1; entries with non-positive tolerance contribute nothing
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
  {
    if (tol > 0.) {
      const auto val = err / tol;

      return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
    } else {
      return {0.0, 0};
    }
  }

  PetscReal ignore_max_;
};

// variant used when the error is computed on the fly as |u - y|; argument tuple is
// (u, y, abs tol, rel tol)
template <NormType wnormtype>
struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto u     = x.get<0>();
    const auto y     = x.get<1>();
    const auto au    = PetscAbsScalar(u);
    const auto ay    = PetscAbsScalar(y);
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(x.get<2>());
    const auto tolr  = skip ? 0.0 : PetscRealPart(x.get<3>()) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(u - y);
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};

// variant used when a separate error vector E is supplied; argument tuple is
// (u, y, e, abs tol, rel tol)
template <NormType wnormtype>
struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto au    = PetscAbsScalar(x.get<0>());
    const auto ay    = PetscAbsScalar(x.get<1>());
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(x.get<3>());
    const auto tolr  = skip ? 0.0 : PetscRealPart(x.get<4>()) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(x.get<2>());
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};

// reduction functor combining the 6-tuples above: max for the norms under NORM_INFINITY (sum
// otherwise), and the counted locations are always summed
template <NormType wnormtype>
struct ErrorWNormReduce {
  using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
  {
    // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
    // result_type is a template, so in order to fix this we would need to write:
    //
    // lhs.template get<0>()
    //
    // which is unseemly.
    if (wnormtype == NORM_INFINITY) {
      // clang-format off
      return {
        PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
        PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
        PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    } else {
      // clang-format off
      return {
        thrust::get<0>(lhs) + thrust::get<0>(rhs),
        thrust::get<1>(lhs) + thrust::get<1>(rhs),
        thrust::get<2>(lhs) + thrust::get<2>(rhs),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    }
  }
};

// runs the transform-reduce for one combination of iterators; first/last are thrust tuples of
// begin/end iterators which get zipped here
template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal
ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
  auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
  PetscReal n = 0, na = 0, nr = 0;
  PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

  PetscFunctionBegin;
  // clang-format off
  if (wnormtype == NORM_INFINITY) {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_INFINITY>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_INFINITY>{}
      )
    );
  } else {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_2>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_2>{}
      )
    );
  }
  // clang-format on
  // the 2-norm accumulates squares, so take the square roots here on the host
  if (wnormtype == NORM_2) {
    *norm  = PetscSqrtReal(*norm);
    *norma = PetscSqrtReal(*norma);
    *normr = PetscSqrtReal(*normr);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace detail

// v->ops->errorwnorm -- weighted error norms between U and Y (optionally with explicit error
// vector E), with scalar or vector absolute/relative tolerances. The eight branches below
// select the iterator combination (constant iterator for scalar tolerances, device pointer for
// vector tolerances, with/without E) and all funnel into detail::ExecuteWNorm().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  const auto         nl  = U->map->n;
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // null vectors map to a null device pointer (those branches never dereference it)
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    if (!vatol && !vrtol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vatol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vrtol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, aptr, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, aptr, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{
// per-entry transform for DotNorm2: returns (s * conj(t), t * conj(t))
struct dotnorm2_mult {
  PETSC_NODISCARD
PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};

// it is positively __bananas__ that thrust does not define default operator+ for tuples... I
// would do it myself but now I am worried that they do so on purpose...
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};

} // namespace detail

// v->ops->dotnorm2 -- computes dp = conj(t)^T s and nm = conj(t)^T t in a single fused
// thrust::inner_product pass using the two functors above
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{
// elementwise complex conjugation functor for PointwiseUnary_
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail

// v->ops->conjugate -- in-place conjugation; a no-op for real builds
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{
// projects a scalar (or (scalar, index) tuple) onto its real part, used by MinMax_ on complex
// builds before the min/max reduction
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};

// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  // selects between two (value, index) candidates; ties are broken towards the lower index
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ? y : x;
    }
    // otherwise return x
    return x;
  }

private:
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};

} // namespace detail

// common core for Min() and Max(): thrust reduction returning the extremal value in *m and,
// when p is non-null, its index in *p (initialized to -1 for empty vectors)
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
#if PetscDefined(USE_COMPLEX)
#define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
#define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
1739 // clang-format off 1740 PetscCallThrust( 1741 thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE( 1742 stream, zip, zip + n, detail::real_part{}, 1743 thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr) 1744 ); 1745 ); 1746 // clang-format on 1747 } else { 1748 // clang-format off 1749 PetscCallThrust( 1750 *m = THRUST_MINMAX_REDUCE( 1751 stream, vptr, vptr + n, detail::real_part{}, 1752 *m, std::forward<UnaryFuncT>(unary_ftr) 1753 ); 1754 ); 1755 // clang-format on 1756 } 1757 } 1758 #undef THRUST_MINMAX_REDUCE 1759 } 1760 // REVIEW ME: flops? 1761 PetscFunctionReturn(PETSC_SUCCESS); 1762 } 1763 1764 // v->ops->max 1765 template <device::cupm::DeviceType T> 1766 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept 1767 { 1768 using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>; 1769 using unary_functor = thrust::maximum<PetscReal>; 1770 1771 PetscFunctionBegin; 1772 *m = PETSC_MIN_REAL; 1773 // use {} constructor syntax otherwise most vexing parse 1774 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1775 PetscFunctionReturn(PETSC_SUCCESS); 1776 } 1777 1778 // v->ops->min 1779 template <device::cupm::DeviceType T> 1780 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept 1781 { 1782 using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>; 1783 using unary_functor = thrust::minimum<PetscReal>; 1784 1785 PetscFunctionBegin; 1786 *m = PETSC_MAX_REAL; 1787 // use {} constructor syntax otherwise most vexing parse 1788 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1789 PetscFunctionReturn(PETSC_SUCCESS); 1790 } 1791 1792 // v->ops->sum 1793 template <device::cupm::DeviceType T> 1794 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept 1795 { 1796 PetscFunctionBegin; 1797 if (const auto n = v->map->n) { 1798 PetscDeviceContext dctx; 1799 cupmStream_t stream; 1800 1801 PetscCall(GetHandles_(&dctx, &stream)); 1802 
const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1803 // REVIEW ME: why not cupmBlasXasum()? 1804 PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0});); 1805 // REVIEW ME: must be at least n additions 1806 PetscCall(PetscLogGpuFlops(n)); 1807 } else { 1808 *sum = 0.0; 1809 } 1810 PetscFunctionReturn(PETSC_SUCCESS); 1811 } 1812 1813 template <device::cupm::DeviceType T> 1814 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept 1815 { 1816 PetscFunctionBegin; 1817 PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v)); 1818 PetscFunctionReturn(PETSC_SUCCESS); 1819 } 1820 1821 template <device::cupm::DeviceType T> 1822 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept 1823 { 1824 PetscFunctionBegin; 1825 if (const auto n = v->map->n) { 1826 PetscBool iscurand; 1827 PetscDeviceContext dctx; 1828 1829 PetscCall(GetHandles_(&dctx)); 1830 PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand)); 1831 if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v))); 1832 else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v))); 1833 } else { 1834 PetscCall(MaybeIncrementEmptyLocalVec(v)); 1835 } 1836 // REVIEW ME: flops???? 1837 // REVIEW ME: Timing??? 
1838 PetscFunctionReturn(PETSC_SUCCESS); 1839 } 1840 1841 // v->ops->setpreallocation 1842 template <device::cupm::DeviceType T> 1843 inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept 1844 { 1845 PetscDeviceContext dctx; 1846 1847 PetscFunctionBegin; 1848 PetscCall(GetHandles_(&dctx)); 1849 PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i)); 1850 PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx)); 1851 PetscFunctionReturn(PETSC_SUCCESS); 1852 } 1853 1854 namespace kernels 1855 { 1856 template <typename F> 1857 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 1858 { 1859 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 1860 const auto end = jmap[i + 1]; 1861 const auto idx = xvindex(i); 1862 PetscScalar sum = 0.0; 1863 1864 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 1865 1866 if (imode == INSERT_VALUES) { 1867 xv[idx] = sum; 1868 } else { 1869 xv[idx] += sum; 1870 } 1871 }); 1872 return; 1873 } 1874 1875 namespace 1876 { 1877 1878 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 1879 { 1880 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 1881 return; 1882 } 1883 1884 } // namespace 1885 1886 #if PetscDefined(USING_HCC) 1887 namespace do_not_use 1888 { 1889 1890 // Needed to silence clang warning: 1891 // 1892 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 1893 // 1894 // The warning is silly, since the function *is* used, however the host compiler does not 1895 // appear see this. 
Likely because the function using it is in a template. 1896 // 1897 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 1898 inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted() 1899 { 1900 (void)sum_kernel; 1901 } 1902 1903 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 1904 { 1905 (void)add_coo_values; 1906 } 1907 1908 } // namespace do_not_use 1909 #endif 1910 1911 } // namespace kernels 1912 1913 // v->ops->setvaluescoo 1914 template <device::cupm::DeviceType T> 1915 inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept 1916 { 1917 auto vv = const_cast<PetscScalar *>(v); 1918 PetscMemType memtype; 1919 PetscDeviceContext dctx; 1920 cupmStream_t stream; 1921 1922 PetscFunctionBegin; 1923 PetscCall(GetHandles_(&dctx, &stream)); 1924 PetscCall(PetscGetMemType(v, &memtype)); 1925 if (PetscMemTypeHost(memtype)) { 1926 const auto size = VecIMPLCast(x)->coo_n; 1927 1928 // If user gave v[] in host, we might need to copy it to device if any 1929 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv)); 1930 PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream)); 1931 } 1932 1933 if (const auto n = x->map->n) { 1934 const auto vcu = VecCUPMCast(x); 1935 1936 PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? 
DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data())); 1937 } else { 1938 PetscCall(MaybeIncrementEmptyLocalVec(x)); 1939 } 1940 1941 if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv)); 1942 PetscCall(PetscDeviceContextSynchronize(dctx)); 1943 PetscFunctionReturn(PETSC_SUCCESS); 1944 } 1945 1946 } // namespace impl 1947 1948 // ========================================================================================== 1949 // VecSeq_CUPM - Implementations 1950 // ========================================================================================== 1951 1952 namespace 1953 { 1954 1955 template <device::cupm::DeviceType T> 1956 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 1957 { 1958 PetscFunctionBegin; 1959 PetscValidPointer(v, 4); 1960 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 1961 PetscFunctionReturn(PETSC_SUCCESS); 1962 } 1963 1964 template <device::cupm::DeviceType T> 1965 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 1966 { 1967 PetscFunctionBegin; 1968 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4); 1969 PetscValidPointer(v, 6); 1970 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 1971 PetscFunctionReturn(PETSC_SUCCESS); 1972 } 1973 1974 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1975 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1976 { 1977 PetscFunctionBegin; 1978 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1979 PetscValidPointer(a, 2); 1980 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1981 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1982 PetscFunctionReturn(PETSC_SUCCESS); 1983 } 1984 1985 template 
<PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1986 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1987 { 1988 PetscFunctionBegin; 1989 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1990 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1991 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1992 PetscFunctionReturn(PETSC_SUCCESS); 1993 } 1994 1995 template <device::cupm::DeviceType T> 1996 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1997 { 1998 PetscFunctionBegin; 1999 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 2000 PetscFunctionReturn(PETSC_SUCCESS); 2001 } 2002 2003 template <device::cupm::DeviceType T> 2004 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2005 { 2006 PetscFunctionBegin; 2007 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 2008 PetscFunctionReturn(PETSC_SUCCESS); 2009 } 2010 2011 template <device::cupm::DeviceType T> 2012 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2013 { 2014 PetscFunctionBegin; 2015 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 2016 PetscFunctionReturn(PETSC_SUCCESS); 2017 } 2018 2019 template <device::cupm::DeviceType T> 2020 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2021 { 2022 PetscFunctionBegin; 2023 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 2024 PetscFunctionReturn(PETSC_SUCCESS); 2025 } 2026 2027 template <device::cupm::DeviceType T> 2028 inline PetscErrorCode 
VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2029 { 2030 PetscFunctionBegin; 2031 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 2032 PetscFunctionReturn(PETSC_SUCCESS); 2033 } 2034 2035 template <device::cupm::DeviceType T> 2036 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2037 { 2038 PetscFunctionBegin; 2039 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 2040 PetscFunctionReturn(PETSC_SUCCESS); 2041 } 2042 2043 template <device::cupm::DeviceType T> 2044 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 2045 { 2046 PetscFunctionBegin; 2047 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 2048 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 2049 PetscFunctionReturn(PETSC_SUCCESS); 2050 } 2051 2052 template <device::cupm::DeviceType T> 2053 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 2054 { 2055 PetscFunctionBegin; 2056 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 2057 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 2058 PetscFunctionReturn(PETSC_SUCCESS); 2059 } 2060 2061 template <device::cupm::DeviceType T> 2062 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 2063 { 2064 PetscFunctionBegin; 2065 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 2066 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 2067 PetscFunctionReturn(PETSC_SUCCESS); 2068 } 2069 2070 } // anonymous namespace 2071 2072 } // namespace cupm 2073 2074 } // namespace vec 2075 2076 } // namespace Petsc 2077 2078 #endif // __cplusplus 2079 2080 #endif // PETSCVECSEQCUPM_HPP 2081