1 #pragma once 2 3 #include <petsc/private/veccupmimpl.h> 4 #include <petsc/private/cpp/utility.hpp> // util::index_sequence 5 6 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D() 7 #include <../src/vec/vec/impls/dvecimpl.h> // Vec_Seq 8 9 namespace Petsc 10 { 11 12 namespace vec 13 { 14 15 namespace cupm 16 { 17 18 namespace impl 19 { 20 21 // ========================================================================================== 22 // VecSeq_CUPM 23 // ========================================================================================== 24 25 template <device::cupm::DeviceType T> 26 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 27 public: 28 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 29 30 private: 31 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 32 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 33 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept; 34 35 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 36 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 37 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 38 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 39 40 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 41 42 // common core for min and max 43 template <typename TupleFuncT, typename UnaryFuncT> 44 static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 45 // common core for pointwise binary and pointwise unary thrust functions 46 template <typename BinaryFuncT> 47 static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 48 template <typename BinaryFuncT> 49 static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 50 template <typename UnaryFuncT> 51 static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 52 // mdot dispatchers 53 static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 54 static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 55 template <std::size_t... Idx> 56 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 57 template <int> 58 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 59 template <std::size_t... Idx> 60 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 61 template <int> 62 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 63 // common core for the various create routines 64 static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 65 66 public: 67 // callable directly via a bespoke function 68 static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 69 static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 70 71 static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept; 72 static PetscErrorCode ClearAsyncFunctions(Vec) noexcept; 73 74 // callable indirectly via function pointers 75 static PetscErrorCode Duplicate(Vec, Vec *) noexcept; 76 static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept; 77 static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept; 78 static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept; 79 static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept; 80 static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept; 81 static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 82 static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept; 83 static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 84 static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept; 85 static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 86 static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept; 87 static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 88 static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept; 89 static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 90 static PetscErrorCode Reciprocal(Vec) noexcept; 91 static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept; 92 static PetscErrorCode Abs(Vec) noexcept; 93 static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept; 94 static PetscErrorCode SqrtAbs(Vec) noexcept; 95 static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept; 96 static PetscErrorCode Exp(Vec) noexcept; 97 static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept; 98 static PetscErrorCode Log(Vec) noexcept; 99 static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept; 100 static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept; 101 static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept; 102 static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 103 static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept; 104 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept; 105 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 106 static PetscErrorCode Set(Vec, PetscScalar) noexcept; 107 static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 108 static PetscErrorCode Scale(Vec, PetscScalar) noexcept; 109 static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 110 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept; 111 static PetscErrorCode Copy(Vec, Vec) noexcept; 112 static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept; 113 static PetscErrorCode Swap(Vec, Vec) noexcept; 114 static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept; 115 static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept; 116 static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept; 117 static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 118 static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept; 119 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept; 120 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept; 121 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 122 static PetscErrorCode Conjugate(Vec) noexcept; 123 static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept; 124 template <PetscMemoryAccessMode> 125 static PetscErrorCode GetLocalVector(Vec, Vec) noexcept; 126 template <PetscMemoryAccessMode> 127 static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept; 128 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept; 129 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept; 130 static PetscErrorCode Sum(Vec, PetscScalar *) noexcept; 131 static PetscErrorCode Shift(Vec, PetscScalar) noexcept; 132 static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 133 static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept; 134 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept; 135 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept; 136 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept; 137 }; 138 139 namespace kernels 140 { 141 142 template <typename F> 143 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 144 { 145 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 146 const auto end = jmap[i + 1]; 147 const auto idx = xvindex(i); 148 PetscScalar sum = 0.0; 149 150 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 151 152 if (imode == INSERT_VALUES) { 153 xv[idx] = sum; 154 } else { 155 xv[idx] += sum; 156 } 157 }); 158 return; 159 } 160 161 namespace 162 { 163 PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function") 164 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 165 { 166 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 167 return; 168 } 169 PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END() 170 } // namespace 171 172 } // namespace kernels 173 174 } // namespace impl 175 176 // ========================================================================================== 177 // VecSeq_CUPM - Implementations 178 // ========================================================================================== 179 180 template <device::cupm::DeviceType T> 181 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 182 { 183 PetscFunctionBegin; 184 PetscAssertPointer(v, 4); 185 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 186 PetscFunctionReturn(PETSC_SUCCESS); 187 } 188 189 template <device::cupm::DeviceType T> 190 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 191 { 192 PetscFunctionBegin; 193 if (n && cpuarray) PetscAssertPointer(cpuarray, 4); 194 PetscAssertPointer(v, 6); 195 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 196 PetscFunctionReturn(PETSC_SUCCESS); 197 } 198 199 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 200 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 201 { 202 PetscFunctionBegin; 203 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 204 PetscAssertPointer(a, 2); 205 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 206 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 207 PetscFunctionReturn(PETSC_SUCCESS); 208 } 209 210 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 211 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 212 { 213 PetscFunctionBegin; 214 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 215 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 216 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 217 PetscFunctionReturn(PETSC_SUCCESS); 218 } 219 220 template <device::cupm::DeviceType T> 221 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 222 { 223 PetscFunctionBegin; 224 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 225 PetscFunctionReturn(PETSC_SUCCESS); 226 } 227 228 template <device::cupm::DeviceType T> 229 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 230 { 231 PetscFunctionBegin; 232 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 233 PetscFunctionReturn(PETSC_SUCCESS); 234 } 235 236 template <device::cupm::DeviceType T> 237 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 238 { 239 PetscFunctionBegin; 240 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 241 PetscFunctionReturn(PETSC_SUCCESS); 242 } 243 244 template <device::cupm::DeviceType T> 245 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 246 { 247 PetscFunctionBegin; 248 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 249 PetscFunctionReturn(PETSC_SUCCESS); 250 } 251 252 template <device::cupm::DeviceType T> 253 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 254 { 255 PetscFunctionBegin; 256 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 257 PetscFunctionReturn(PETSC_SUCCESS); 258 } 259 260 template <device::cupm::DeviceType T> 261 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 262 { 263 PetscFunctionBegin; 264 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 265 PetscFunctionReturn(PETSC_SUCCESS); 266 } 267 268 template <device::cupm::DeviceType T> 269 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 270 { 271 PetscFunctionBegin; 272 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 273 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 274 PetscFunctionReturn(PETSC_SUCCESS); 275 } 276 277 template <device::cupm::DeviceType T> 278 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 279 { 280 PetscFunctionBegin; 281 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 282 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 283 PetscFunctionReturn(PETSC_SUCCESS); 284 } 285 286 template <device::cupm::DeviceType T> 287 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 288 { 289 PetscFunctionBegin; 290 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 291 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 292 PetscFunctionReturn(PETSC_SUCCESS); 293 } 294 295 } // namespace cupm 296 297 } // namespace vec 298 299 } // namespace Petsc 300 301 #if PetscDefined(HAVE_CUDA) 302 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>; 303 #endif 304 305 #if PetscDefined(HAVE_HIP) 306 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>; 307 #endif 308