1 #pragma once 2 3 #include <petsc/private/veccupmimpl.h> 4 #include <petsc/private/cpp/utility.hpp> // util::index_sequence 5 6 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D() 7 #include <../src/vec/vec/impls/dvecimpl.h> // Vec_Seq 8 9 namespace Petsc 10 { 11 12 namespace vec 13 { 14 15 namespace cupm 16 { 17 18 namespace impl 19 { 20 21 // ========================================================================================== 22 // VecSeq_CUPM 23 // ========================================================================================== 24 25 template <device::cupm::DeviceType T> 26 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 27 public: 28 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 29 30 private: 31 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 32 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 33 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept; 34 35 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 36 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 37 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 38 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 39 40 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 41 42 // common core for min and max 43 template <typename TupleFuncT, typename UnaryFuncT> 44 static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 45 // common core for pointwise binary and pointwise unary thrust functions 46 template <typename BinaryFuncT> 47 static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 48 template <typename BinaryFuncT> 49 static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 50 template <typename UnaryFuncT> 51 static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 52 // mdot dispatchers 53 static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 54 static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 55 template <std::size_t... Idx> 56 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 57 template <int> 58 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 59 template <std::size_t... Idx> 60 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 61 template <int> 62 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 63 // common core for the various create routines 64 static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 65 66 public: 67 // callable directly via a bespoke function 68 static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 69 static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 70 71 static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept; 72 static PetscErrorCode ClearAsyncFunctions(Vec) noexcept; 73 74 // callable indirectly via function pointers 75 static PetscErrorCode Duplicate(Vec, Vec *) noexcept; 76 static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept; 77 static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept; 78 static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept; 79 static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept; 80 static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept; 81 static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 82 static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept; 83 static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 84 static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept; 85 static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 86 static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept; 87 static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 88 static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept; 89 static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 90 static PetscErrorCode Reciprocal(Vec) noexcept; 91 static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept; 92 static PetscErrorCode Abs(Vec) noexcept; 93 static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept; 94 static PetscErrorCode SqrtAbs(Vec) noexcept; 95 static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept; 96 static PetscErrorCode Exp(Vec) noexcept; 97 static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept; 98 static PetscErrorCode Log(Vec) noexcept; 99 static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept; 100 static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept; 101 static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept; 102 static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 103 static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept; 104 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept; 105 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 106 static PetscErrorCode Set(Vec, PetscScalar) noexcept; 107 static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 108 static PetscErrorCode Scale(Vec, PetscScalar) noexcept; 109 static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 110 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept; 111 static PetscErrorCode Copy(Vec, Vec) noexcept; 112 static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept; 113 static PetscErrorCode Swap(Vec, Vec) noexcept; 114 static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept; 115 static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept; 116 static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept; 117 static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 118 static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept; 119 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept; 120 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept; 121 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 122 static PetscErrorCode Conjugate(Vec) noexcept; 123 static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept; 124 template <PetscMemoryAccessMode> 125 static PetscErrorCode GetLocalVector(Vec, Vec) noexcept; 126 template <PetscMemoryAccessMode> 127 static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept; 128 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept; 129 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept; 130 static PetscErrorCode Sum(Vec, PetscScalar *) noexcept; 131 static PetscErrorCode Shift(Vec, PetscScalar) noexcept; 132 static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 133 static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept; 134 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept; 135 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept; 136 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept; 137 }; 138 139 namespace kernels 140 { 141 142 template <typename F> 143 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 144 { 145 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 146 const auto end = jmap[i + 1]; 147 const auto idx = xvindex(i); 148 PetscScalar sum = 0.0; 149 150 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 151 152 if (imode == INSERT_VALUES) { 153 xv[idx] = sum; 154 } else { 155 xv[idx] += sum; 156 } 157 }); 158 return; 159 } 160 161 namespace 162 { 163 164 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 165 { 166 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 167 return; 168 } 169 170 } // namespace 171 172 #if PetscDefined(USING_HCC) 173 namespace do_not_use 174 { 175 176 // Needed to silence clang warning: 177 // 178 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 179 // 180 // The warning is silly, since the function *is* used, however the host compiler does not 181 // appear see this. Likely because the function using it is in a template. 182 // 183 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 184 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 185 { 186 (void)add_coo_values; 187 } 188 189 } // namespace do_not_use 190 #endif 191 192 } // namespace kernels 193 194 } // namespace impl 195 196 // ========================================================================================== 197 // VecSeq_CUPM - Implementations 198 // ========================================================================================== 199 200 template <device::cupm::DeviceType T> 201 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 202 { 203 PetscFunctionBegin; 204 PetscAssertPointer(v, 4); 205 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 206 PetscFunctionReturn(PETSC_SUCCESS); 207 } 208 209 template <device::cupm::DeviceType T> 210 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 211 { 212 PetscFunctionBegin; 213 if (n && cpuarray) PetscAssertPointer(cpuarray, 4); 214 PetscAssertPointer(v, 6); 215 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 216 PetscFunctionReturn(PETSC_SUCCESS); 217 } 218 219 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 220 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 221 { 222 PetscFunctionBegin; 223 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 224 PetscAssertPointer(a, 2); 225 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 226 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 227 PetscFunctionReturn(PETSC_SUCCESS); 228 } 229 230 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 231 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 232 { 233 PetscFunctionBegin; 234 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 235 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 236 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 237 PetscFunctionReturn(PETSC_SUCCESS); 238 } 239 240 template <device::cupm::DeviceType T> 241 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 242 { 243 PetscFunctionBegin; 244 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 245 PetscFunctionReturn(PETSC_SUCCESS); 246 } 247 248 template <device::cupm::DeviceType T> 249 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 250 { 251 PetscFunctionBegin; 252 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 253 PetscFunctionReturn(PETSC_SUCCESS); 254 } 255 256 template <device::cupm::DeviceType T> 257 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 258 { 259 PetscFunctionBegin; 260 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 261 PetscFunctionReturn(PETSC_SUCCESS); 262 } 263 264 template <device::cupm::DeviceType T> 265 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 266 { 267 PetscFunctionBegin; 268 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 269 PetscFunctionReturn(PETSC_SUCCESS); 270 } 271 272 template <device::cupm::DeviceType T> 273 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 274 { 275 PetscFunctionBegin; 276 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 277 PetscFunctionReturn(PETSC_SUCCESS); 278 } 279 280 template <device::cupm::DeviceType T> 281 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 282 { 283 PetscFunctionBegin; 284 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 285 PetscFunctionReturn(PETSC_SUCCESS); 286 } 287 288 template <device::cupm::DeviceType T> 289 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 290 { 291 PetscFunctionBegin; 292 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 293 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 294 PetscFunctionReturn(PETSC_SUCCESS); 295 } 296 297 template <device::cupm::DeviceType T> 298 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 299 { 300 PetscFunctionBegin; 301 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 302 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 303 PetscFunctionReturn(PETSC_SUCCESS); 304 } 305 306 template <device::cupm::DeviceType T> 307 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 308 { 309 PetscFunctionBegin; 310 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 311 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 312 PetscFunctionReturn(PETSC_SUCCESS); 313 } 314 315 } // namespace cupm 316 317 } // namespace vec 318 319 } // namespace Petsc 320 321 #if PetscDefined(HAVE_CUDA) 322 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>; 323 #endif 324 325 #if PetscDefined(HAVE_HIP) 326 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>; 327 #endif 328