#ifndef PETSCVECSEQCUPM_HPP
#define PETSCVECSEQCUPM_HPP

#include <petsc/private/veccupmimpl.h>

#include <petsc/private/randomimpl.h> // for _p_PetscRandom

#include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence

#include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
#include "../src/sys/objects/device/impls/cupm/kernels.hpp"

#if PetscDefined(USE_COMPLEX)
#include <thrust/transform_reduce.h>
#endif
#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/inner_product.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

// ==========================================================================================
// VecSeq_CUPM
// ==========================================================================================

// Implementation of the sequential (single-rank) vector for CUPM devices (CUDA or HIP,
// selected by the DeviceType template parameter T). Derives from Vec_CUPMBase via CRTP;
// the base class supplies the device-array accessors (DeviceArrayRead()/Write()/ReadWrite()),
// handle getters (GetHandles_/GetHandlesFrom_) and creation helpers used below.
// NOTE(review): inheritance is implicitly private here; presumably
// PETSC_VEC_CUPM_BASE_CLASS_HEADER re-exposes the needed base members -- confirm against
// veccupmimpl.h.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // cast v->data to the host-side sequential implementation struct
  PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
  // VecType names for the device-side and host-side flavors of this vector
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  // thin forwarders to the corresponding host VecSeq routines
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump object state on ranks whose local size is 0 (see comment at the definition)
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers, selected on whether PetscScalar is complex
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers (installed on v->ops)
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

// ==========================================================================================
// VecSeq_CUPM - Private API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}

template <device::cupm::DeviceType T>
inline PetscErrorCode
VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

// Create the host-side Vec_Seq backing for v (optionally adopting host_array); errors out
// unless the communicator has exactly one rank
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// for functions with an early return based on vec size we still need to artificially bump the
// object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//      rank 0: [0],
//      rank 1: [<empty>]
//    }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  // only the "locally empty but globally nonempty" case needs the artificial state bump
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// common core of the create routines: build the host backing then initialize the CUPM side,
// optionally adopting caller-provided host and/or device arrays
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Apply a thrust binary functor elementwise: zout[i] = binary(xin[i], yin[i]). On empty local
// vecs only the object-state bump is done (see MaybeIncrementEmptyLocalVec()).
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Apply a thrust unary functor elementwise, in-place when yin is null or aliases xinout,
// otherwise out-of-place reading xinout and writing yin
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;
    // shared launcher; when yin is null (in-place) the output iterator is xptr itself
    const auto apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // state bump goes to whichever vector would have been written
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Constructors
// ==========================================================================================

// VecCreateSeqCUPM()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  // sequential: local and global sizes coincide (n passed twice)
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCreateSeqCUPMWithArrays()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const
PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->duplicate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Utility
// ==========================================================================================

// v->ops->bindtocpu
// After delegating to the base class, repoint the op table: host (VecXXX_Seq) implementations
// when usehost, otherwise the device implementations from this class
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device implementation here; always use the host version
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  VecSetOp_CUPM(errorwnorm, nullptr, ErrorWnorm);
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Mutators
// ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread
// Make w an alias for the local part of v. Any storage w already owns (host and device) is
// released first so w never shares-and-owns at the same time.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // drop w's own host allocation (freed with the pinned-memory allocator it was made with)
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    // drop w's own device allocation
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: steal v's pointers wholesale; undone in RestoreLocalVector()
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // generic path: borrow v's host array via the public accessor
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Inverse of GetLocalVector(): give v its pointers back and sever w's aliases
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // device array allocated in GetLocalVector() (DeviceAllocateCheck_) must be freed here
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Compute Methods
// ==========================================================================================

// v->ops->aypx
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    // y = 0*y + x reduces to a plain device-to-device copy
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // alpha == 1: y += x, a single axpy
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        // general case: y = alpha*y then y += x
        const auto one = cupmScalarCast(1.0);

        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpy
// y += alpha*x via cupmBLAS when x lives on the device, otherwise fall back to the host kernel
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisedivide
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, win is written!
    PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemult
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall(VecPointwiseMult_Seq(win, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, win is written!
    PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// functor for VecReciprocal(): maps s -> 1/s, leaving zero entries at zero
struct reciprocal {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
  {
    // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
    // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
    // everything in PetscScalar...
    return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
  }
};

} // namespace detail

// v->ops->reciprocal
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->waxpy
// w = alpha*x + y, computed as w <- y (async copy) then w += alpha*x (axpy)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

template
<typename... Args>
// Device kernel for MAXPY: x[i] += sum_j a[j]*y_j[i] for the sizeof...(Args) vectors passed
// as trailing pointer arguments. The a coefficients are staged in shared memory first.
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
#else
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
//    typename repeat_type<MyType, IdxParamPack>...
// expands to
//    MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};

} // namespace detail

// Launch MAXPY_kernel over sizeof...(Idx) y-vectors; the index sequence both counts the
// vectors and expands their device pointers into the kernel argument list
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Batch-of-N wrapper: dispatch one kernel over yin[yidx..yidx+N) and advance yidx by N
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->maxpy
// x += sum_j alpha[j]*y[j], processed in batches of up to 8 vectors per kernel launch
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    // stage the coefficients on the device once for all batches
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // consume the remaining vectors in the largest batch size <= 8 that fits exactly,
      // falling to the default (8) while 8 or more remain
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->dot
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// grid and block dimensions used by the MDot kernels below
#define MDOT_WORKGROUP_NUM 128
#define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

namespace kernels
{

// number of vector entries each block is responsible for (at least 1)
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

template <typename...
ConstPetscScalarPointer>
// Device kernel computing the sizeof...(y) partial dot products x.y_j; each block writes one
// partial sum per y-vector into results (laid out [y-index][block]), to be reduced by
// sum_kernel() below
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

#pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

#pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stage each thread's per-vector partial sums in shared memory
#pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
#pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

namespace
{

// Reduce the per-block partial sums produced by MDot_kernel: results[i] becomes the sum of
// the i-th group of MDOT_WORKGROUP_SIZE entries. The reduction is done through a per-thread
// local buffer so the writes cannot race with other threads still reading results.
// NOTE(review): local_results has a fixed capacity of 8, which appears to assume each thread
// handles at most 8 grid-stride iterations (i.e. size <= 8 * total threads) -- confirm
// against the launch configuration in MDot_().
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  | ______________________________________________________/
  //  | /    <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // anonymous namespace

} // namespace kernels

template <device::cupm::DeviceType T>
template <std::size_t...
Idx>
// Launches one MDot_kernel instantiation computing sizeof...(Idx) dot products of x against
// yin[0..sizeof...(Idx)); per-block partials are written into `results`.
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Convenience wrapper: dispatches a batch of N dot products starting at yin[yidx] and
// advances yidx past the batch.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// MDot_ -- real-scalar implementation (use_complex = false): batches the y-vectors in groups
// of up to 8, launching one fused kernel per batch (round-robined over forked sub-streams),
// then reduces the per-block partials with sum_kernel and copies the results to the host.
// A trailing singleton batch is handled directly with cupmBlasXdot.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto   num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto   n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a lone remaining vector: cheaper to hand straight to cuBLAS than to launch the
        // batched kernel for a single dot product
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE
// MDot_ -- complex-scalar implementation (use_complex = true): the fused kernel above does not
// conjugate, so each dot product is issued as an individual cupmBlasXdot over up to 8 forked
// sub-contexts, with the handle switched to device pointer mode so results accumulate in d_z
// without host round-trips.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // temporarily flip the handle to device pointer mode so cupmBlasXdot writes its result
    // directly to d_z + i on the device; restore the caller's mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot
//
// z[i] = conj(yin[i])^H xin for i in [0, nv). Dispatches to the real or complex MDot_
// implementation at compile time; nv == 1 falls back to Dot() and empty local vectors just
// zero the output.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->set
//
// Fills xin with alpha. alpha == 0 uses an async memset (cheaper); anything else goes through
// thrust::fill on the device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto xptr = DeviceArrayWrite(dctx, xin);

    if (alpha == PetscScalar(0.0)) {
      PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
    } else {
      const auto dptr = thrust::device_pointer_cast(xptr.data());

      PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
    }
    if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->scale
//
// xin *= alpha. alpha == 1 is a no-op, alpha == 0 defers to Set(), otherwise cupmBlasXscal.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(Set(xin, alpha));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: still bump the object state so offload masks stay consistent
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->tdot
//
// Indefinite (unconjugated) dot product z = y^T x via cupmBlasXdotu.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->copy
//
// yout = xin. The memcpy direction (host/device on either side) is chosen from the two
// vectors' offload masks so data is moved from wherever xin is currently valid to wherever
// yout should live.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->swap
//
// Exchanges the contents of xin and yin on the device via cupmBlasXswap.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpby
//
// y = alpha*x + beta*y. The degenerate alpha/beta values are peeled off to the cheaper
// Scale/AXPY/AYPX operations before falling back to an explicit scal+axpy pair.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    PetscCall(AXPY(yin, alpha, xin));
  } else if (alpha ==
PetscScalar(1.0)) {
    PetscCall(AYPX(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        // y = alpha*x realized as copy-then-scale
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        // y = beta*y, then y += alpha*x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpbypcz
//
// z = alpha*x + beta*y + gamma*z, composed from Scale and two AXPYs.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma));
  PetscCall(AXPY(zin, alpha, xin));
  PetscCall(AXPY(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->norm
//
// Computes the requested norm(s) of xin via cuBLAS: asum for NORM_1, nrm2 for NORM_2/Frobenius,
// amax + a single-element copy for NORM_INFINITY. NORM_1_AND_2 writes z[0] and z[1] by falling
// through from the NORM_1 case with an incremented z pointer.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      // amax returns a 1-based index of the largest-magnitude entry; fetch that entry
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // NOTE(review): xv is read on the host immediately after this async D2H copy with no
      // explicit sync; presumably the copy to pageable host memory blocks, or the stream is
      // host-synchronous here -- confirm
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: zero the one (or, for NORM_1_AND_2, both) requested norms
    z[0] = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Shared machinery for the weighted-error-norm functors below. The reduced result_type tuple
// is (norm, norm_abs, norm_rel, count, count_abs, count_rel) where the counts track how many
// entries contributed to each norm.
template <NormType wnormtype>
class ErrorWNormTransformBase {
public:
  using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

  constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

protected:
  // a single entry's contribution: its (possibly squared) scaled error plus a 0/1 flag
  // marking whether it participated (tol > 0)
  struct NormTuple {
    PetscReal norm;
    PetscInt  loc;
  };

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
  {
    if (tol > 0.) {
      const auto val = err / tol;

      // infinity-norm keeps the raw ratio (reduced with max); 2-norm accumulates squares
      return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
    } else {
      return {0.0, 0};
    }
  }

  PetscReal ignore_max_;
};

// Per-entry transform when no explicit error vector is given: the error is |u - y|, with
// inputs (u, y, atol, rtol).
template <NormType wnormtype>
struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto u    = x.get<0>();
    const auto y    = x.get<1>();
    const auto au   = PetscAbsScalar(u);
    const auto ay   = PetscAbsScalar(y);
    // entries below the ignore threshold contribute nothing (tolerances forced to zero)
    const auto skip = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola = skip ? 0.0 : PetscRealPart(x.get<2>());
    const auto tolr = skip ? 0.0 : PetscRealPart(x.get<3>()) * PetscMax(au, ay);
    const auto tol  = tola + tolr;
    const auto err  = PetscAbsScalar(u - y);
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};

// Per-entry transform when an explicit error vector e is given: inputs (u, y, e, atol, rtol),
// error is |e| instead of |u - y|.
template <NormType wnormtype>
struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto au   = PetscAbsScalar(x.get<0>());
    const auto ay   = PetscAbsScalar(x.get<1>());
    const auto skip = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola = skip ? 0.0 : PetscRealPart(x.get<3>());
    const auto tolr = skip ? 0.0 : PetscRealPart(x.get<4>()) * PetscMax(au, ay);
    const auto tol  = tola + tolr;
    const auto err  = PetscAbsScalar(x.get<2>());
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};

// Binary reduction for the 6-tuples above: max over the three norms for NORM_INFINITY, sum
// otherwise; the participation counts are always summed.
template <NormType wnormtype>
struct ErrorWNormReduce {
  using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
  {
    // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
    // result_type is a template, so in order to fix this we would need to write:
    //
    //   lhs.template get<0>()
    //
    // which is unseemly.
    if (wnormtype == NORM_INFINITY) {
      // clang-format off
      return {
        PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
        PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
        PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    } else {
      // clang-format off
      return {
        thrust::get<0>(lhs) + thrust::get<0>(rhs),
        thrust::get<1>(lhs) + thrust::get<1>(rhs),
        thrust::get<2>(lhs) + thrust::get<2>(rhs),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    }
  }
};

// Runs the transform_reduce for a given transform functor over the zipped iterator tuple,
// writing the three norms (and their contribution counts) to the output pointers; 2-norm
// results are square-rooted at the end.
template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal
ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
  auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
  PetscReal n = 0, na = 0, nr = 0;
  PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

  PetscFunctionBegin;
  // clang-format off
  if (wnormtype == NORM_INFINITY) {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_INFINITY>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_INFINITY>{}
      )
    );
  } else {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_2>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_2>{}
      )
    );
  }
  // clang-format on
  // the 2-norm path accumulated squares; take the root here
  if (wnormtype == NORM_2) {
    *norm  = PetscSqrtReal(*norm);
    *norma = PetscSqrtReal(*norma);
    *normr = PetscSqrtReal(*normr);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace detail

// v->ops->errorwnorm
//
// Weighted error norms used by the TS adaptors. vatol/vrtol (vector tolerances) and E (explicit
// error vector) are each optional; the 2x4 combinations select the transform functor
// (ErrorWNormETransform when E is given) and substitute constant iterators for the missing
// tolerance vectors.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  const auto         nl  = U->map->n;
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // device pointer for an optional vector, or a null device_ptr if it was not supplied
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    if (!vatol && !vrtol) {
      // scalar atol, scalar rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vatol) {
      // scalar atol, vector rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vrtol) {
      // vector atol, scalar rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, aptr, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else {
      // vector atol, vector rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, aptr, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// element-wise functor for DotNorm2: given (s_i, t_i) returns (s_i * conj(t_i), t_i * conj(t_i))
struct dotnorm2_mult {
  PETSC_NODISCARD
PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};

// it is positively __bananas__ that thrust does not define default operator+ for tuples... I
// would do it myself but now I am worried that they do so on purpose...
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};

} // namespace detail

// v->ops->dotnorm2
//
// Computes dp = conj(t)^H s and nm = conj(t)^H t in a single fused pass over both vectors
// using thrust::inner_product with the tuple functors above.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// element-wise complex conjugation, used by Conjugate() below
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail

// v->ops->conjugate
//
// In-place complex conjugation of xin; a no-op for real-scalar builds.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// strips the real part from a scalar (or from the scalar slot of a (value, index) tuple),
// used on complex builds before the min/max reductions
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};

// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  // selects the (value, index) tuple preferred by Operator; ties on the value are broken in
  // favor of the lower index
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ?
y : x; 1695 } 1696 // otherwise return x 1697 return x; 1698 } 1699 1700 private: 1701 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; } 1702 }; 1703 1704 } // namespace detail 1705 1706 template <device::cupm::DeviceType T> 1707 template <typename TupleFuncT, typename UnaryFuncT> 1708 inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept 1709 { 1710 PetscFunctionBegin; 1711 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM()); 1712 if (p) *p = -1; 1713 if (const auto n = v->map->n) { 1714 PetscDeviceContext dctx; 1715 cupmStream_t stream; 1716 1717 PetscCall(GetHandles_(&dctx, &stream)); 1718 // needed to: 1719 // 1. switch between transform_reduce and reduce 1720 // 2. strip the real_part functor from the arguments 1721 #if PetscDefined(USE_COMPLEX) 1722 #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__) 1723 #else 1724 #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__) 1725 #endif 1726 { 1727 const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1728 1729 if (p) { 1730 // clang-format off 1731 const auto zip = thrust::make_zip_iterator( 1732 thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0})) 1733 ); 1734 // clang-format on 1735 // need to use preprocessor conditionals since otherwise thrust complains about not being 1736 // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex 1737 // builds... 
1738 // clang-format off 1739 PetscCallThrust( 1740 thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE( 1741 stream, zip, zip + n, detail::real_part{}, 1742 thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr) 1743 ); 1744 ); 1745 // clang-format on 1746 } else { 1747 // clang-format off 1748 PetscCallThrust( 1749 *m = THRUST_MINMAX_REDUCE( 1750 stream, vptr, vptr + n, detail::real_part{}, 1751 *m, std::forward<UnaryFuncT>(unary_ftr) 1752 ); 1753 ); 1754 // clang-format on 1755 } 1756 } 1757 #undef THRUST_MINMAX_REDUCE 1758 } 1759 // REVIEW ME: flops? 1760 PetscFunctionReturn(PETSC_SUCCESS); 1761 } 1762 1763 // v->ops->max 1764 template <device::cupm::DeviceType T> 1765 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept 1766 { 1767 using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>; 1768 using unary_functor = thrust::maximum<PetscReal>; 1769 1770 PetscFunctionBegin; 1771 *m = PETSC_MIN_REAL; 1772 // use {} constructor syntax otherwise most vexing parse 1773 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1774 PetscFunctionReturn(PETSC_SUCCESS); 1775 } 1776 1777 // v->ops->min 1778 template <device::cupm::DeviceType T> 1779 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept 1780 { 1781 using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>; 1782 using unary_functor = thrust::minimum<PetscReal>; 1783 1784 PetscFunctionBegin; 1785 *m = PETSC_MAX_REAL; 1786 // use {} constructor syntax otherwise most vexing parse 1787 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m)); 1788 PetscFunctionReturn(PETSC_SUCCESS); 1789 } 1790 1791 // v->ops->sum 1792 template <device::cupm::DeviceType T> 1793 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept 1794 { 1795 PetscFunctionBegin; 1796 if (const auto n = v->map->n) { 1797 PetscDeviceContext dctx; 1798 cupmStream_t stream; 1799 1800 PetscCall(GetHandles_(&dctx, &stream)); 1801 
const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data()); 1802 // REVIEW ME: why not cupmBlasXasum()? 1803 PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0});); 1804 // REVIEW ME: must be at least n additions 1805 PetscCall(PetscLogGpuFlops(n)); 1806 } else { 1807 *sum = 0.0; 1808 } 1809 PetscFunctionReturn(PETSC_SUCCESS); 1810 } 1811 1812 template <device::cupm::DeviceType T> 1813 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept 1814 { 1815 PetscFunctionBegin; 1816 PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v)); 1817 PetscFunctionReturn(PETSC_SUCCESS); 1818 } 1819 1820 template <device::cupm::DeviceType T> 1821 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept 1822 { 1823 PetscFunctionBegin; 1824 if (const auto n = v->map->n) { 1825 PetscBool iscurand; 1826 PetscDeviceContext dctx; 1827 1828 PetscCall(GetHandles_(&dctx)); 1829 PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand)); 1830 if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v))); 1831 else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v))); 1832 } else { 1833 PetscCall(MaybeIncrementEmptyLocalVec(v)); 1834 } 1835 // REVIEW ME: flops???? 1836 // REVIEW ME: Timing??? 
1837 PetscFunctionReturn(PETSC_SUCCESS); 1838 } 1839 1840 // v->ops->setpreallocation 1841 template <device::cupm::DeviceType T> 1842 inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept 1843 { 1844 PetscDeviceContext dctx; 1845 1846 PetscFunctionBegin; 1847 PetscCall(GetHandles_(&dctx)); 1848 PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i)); 1849 PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx)); 1850 PetscFunctionReturn(PETSC_SUCCESS); 1851 } 1852 1853 namespace kernels 1854 { 1855 template <typename F> 1856 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 1857 { 1858 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 1859 const auto end = jmap[i + 1]; 1860 const auto idx = xvindex(i); 1861 PetscScalar sum = 0.0; 1862 1863 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 1864 1865 if (imode == INSERT_VALUES) { 1866 xv[idx] = sum; 1867 } else { 1868 xv[idx] += sum; 1869 } 1870 }); 1871 return; 1872 } 1873 1874 namespace 1875 { 1876 1877 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 1878 { 1879 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 1880 return; 1881 } 1882 1883 } // namespace 1884 1885 #if PetscDefined(USING_HCC) 1886 namespace do_not_use 1887 { 1888 1889 // Needed to silence clang warning: 1890 // 1891 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 1892 // 1893 // The warning is silly, since the function *is* used, however the host compiler does not 1894 // appear see this. 
Likely because the function using it is in a template. 1895 // 1896 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 1897 inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted() 1898 { 1899 (void)sum_kernel; 1900 } 1901 1902 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 1903 { 1904 (void)add_coo_values; 1905 } 1906 1907 } // namespace do_not_use 1908 #endif 1909 1910 } // namespace kernels 1911 1912 // v->ops->setvaluescoo 1913 template <device::cupm::DeviceType T> 1914 inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept 1915 { 1916 auto vv = const_cast<PetscScalar *>(v); 1917 PetscMemType memtype; 1918 PetscDeviceContext dctx; 1919 cupmStream_t stream; 1920 1921 PetscFunctionBegin; 1922 PetscCall(GetHandles_(&dctx, &stream)); 1923 PetscCall(PetscGetMemType(v, &memtype)); 1924 if (PetscMemTypeHost(memtype)) { 1925 const auto size = VecIMPLCast(x)->coo_n; 1926 1927 // If user gave v[] in host, we might need to copy it to device if any 1928 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv)); 1929 PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream)); 1930 } 1931 1932 if (const auto n = x->map->n) { 1933 const auto vcu = VecCUPMCast(x); 1934 1935 PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? 
DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data())); 1936 } else { 1937 PetscCall(MaybeIncrementEmptyLocalVec(x)); 1938 } 1939 1940 if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv)); 1941 PetscCall(PetscDeviceContextSynchronize(dctx)); 1942 PetscFunctionReturn(PETSC_SUCCESS); 1943 } 1944 1945 } // namespace impl 1946 1947 // ========================================================================================== 1948 // VecSeq_CUPM - Implementations 1949 // ========================================================================================== 1950 1951 namespace 1952 { 1953 1954 template <device::cupm::DeviceType T> 1955 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 1956 { 1957 PetscFunctionBegin; 1958 PetscValidPointer(v, 4); 1959 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 1960 PetscFunctionReturn(PETSC_SUCCESS); 1961 } 1962 1963 template <device::cupm::DeviceType T> 1964 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 1965 { 1966 PetscFunctionBegin; 1967 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4); 1968 PetscValidPointer(v, 6); 1969 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 1970 PetscFunctionReturn(PETSC_SUCCESS); 1971 } 1972 1973 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1974 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1975 { 1976 PetscFunctionBegin; 1977 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1978 PetscValidPointer(a, 2); 1979 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1980 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1981 PetscFunctionReturn(PETSC_SUCCESS); 1982 } 1983 1984 template 
<PetscMemoryAccessMode mode, device::cupm::DeviceType T> 1985 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 1986 { 1987 PetscFunctionBegin; 1988 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 1989 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 1990 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 1991 PetscFunctionReturn(PETSC_SUCCESS); 1992 } 1993 1994 template <device::cupm::DeviceType T> 1995 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 1996 { 1997 PetscFunctionBegin; 1998 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 1999 PetscFunctionReturn(PETSC_SUCCESS); 2000 } 2001 2002 template <device::cupm::DeviceType T> 2003 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2004 { 2005 PetscFunctionBegin; 2006 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 2007 PetscFunctionReturn(PETSC_SUCCESS); 2008 } 2009 2010 template <device::cupm::DeviceType T> 2011 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2012 { 2013 PetscFunctionBegin; 2014 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 2015 PetscFunctionReturn(PETSC_SUCCESS); 2016 } 2017 2018 template <device::cupm::DeviceType T> 2019 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2020 { 2021 PetscFunctionBegin; 2022 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 2023 PetscFunctionReturn(PETSC_SUCCESS); 2024 } 2025 2026 template <device::cupm::DeviceType T> 2027 inline PetscErrorCode 
VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2028 { 2029 PetscFunctionBegin; 2030 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 2031 PetscFunctionReturn(PETSC_SUCCESS); 2032 } 2033 2034 template <device::cupm::DeviceType T> 2035 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 2036 { 2037 PetscFunctionBegin; 2038 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 2039 PetscFunctionReturn(PETSC_SUCCESS); 2040 } 2041 2042 template <device::cupm::DeviceType T> 2043 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 2044 { 2045 PetscFunctionBegin; 2046 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 2047 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 2048 PetscFunctionReturn(PETSC_SUCCESS); 2049 } 2050 2051 template <device::cupm::DeviceType T> 2052 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 2053 { 2054 PetscFunctionBegin; 2055 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 2056 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 2057 PetscFunctionReturn(PETSC_SUCCESS); 2058 } 2059 2060 template <device::cupm::DeviceType T> 2061 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 2062 { 2063 PetscFunctionBegin; 2064 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 2065 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 2066 PetscFunctionReturn(PETSC_SUCCESS); 2067 } 2068 2069 } // anonymous namespace 2070 2071 } // namespace cupm 2072 2073 } // namespace vec 2074 2075 } // namespace Petsc 2076 2077 #endif // PETSCVECSEQCUPM_HPP 2078