1 #ifndef PETSCVECSEQCUPM_HPP 2 #define PETSCVECSEQCUPM_HPP 3 4 #include <petsc/private/veccupmimpl.h> 5 #include <petsc/private/cpp/utility.hpp> // util::index_sequence 6 7 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D() 8 #include <../src/vec/vec/impls/dvecimpl.h> // Vec_Seq 9 10 namespace Petsc 11 { 12 13 namespace vec 14 { 15 16 namespace cupm 17 { 18 19 namespace impl 20 { 21 22 // ========================================================================================== 23 // VecSeq_CUPM 24 // ========================================================================================== 25 26 template <device::cupm::DeviceType T> 27 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 28 public: 29 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 30 31 private: 32 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 33 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 34 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept; 35 36 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 37 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 38 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 39 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 40 41 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 42 43 // common core for min and max 44 template <typename TupleFuncT, typename UnaryFuncT> 45 static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 46 // common core for pointwise binary and pointwise unary thrust functions 47 template <typename BinaryFuncT> 48 static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 49 template <typename BinaryFuncT> 50 static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 51 template <typename UnaryFuncT> 52 static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept; 53 // mdot dispatchers 54 static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 55 static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 56 template <std::size_t... Idx> 57 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 58 template <int> 59 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 60 template <std::size_t... Idx> 61 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 62 template <int> 63 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 64 // common core for the various create routines 65 static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 66 67 public: 68 // callable directly via a bespoke function 69 static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 70 static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 71 72 static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept; 73 static PetscErrorCode ClearAsyncFunctions(Vec) noexcept; 74 75 // callable indirectly via function pointers 76 static PetscErrorCode Duplicate(Vec, Vec *) noexcept; 77 static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept; 78 static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept; 79 static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept; 80 static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept; 81 static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept; 82 static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 83 static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept; 84 static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 85 static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept; 86 static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 87 static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept; 88 static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 89 static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept; 90 static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept; 91 static PetscErrorCode Reciprocal(Vec) noexcept; 92 static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept; 93 static PetscErrorCode Abs(Vec) noexcept; 94 static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept; 95 static PetscErrorCode SqrtAbs(Vec) noexcept; 96 static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept; 97 static PetscErrorCode Exp(Vec) noexcept; 98 static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept; 99 static PetscErrorCode Log(Vec) noexcept; 100 static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept; 101 static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept; 102 static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept; 103 static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 104 static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept; 105 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept; 106 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 107 static PetscErrorCode Set(Vec, PetscScalar) noexcept; 108 static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 109 static PetscErrorCode Scale(Vec, PetscScalar) noexcept; 110 static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 111 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept; 112 static PetscErrorCode Copy(Vec, Vec) noexcept; 113 static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept; 114 static PetscErrorCode Swap(Vec, Vec) noexcept; 115 static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept; 116 static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept; 117 static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept; 118 static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 119 static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept; 120 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept; 121 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept; 122 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 123 static PetscErrorCode Conjugate(Vec) noexcept; 124 static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept; 125 template <PetscMemoryAccessMode> 126 static PetscErrorCode GetLocalVector(Vec, Vec) noexcept; 127 template <PetscMemoryAccessMode> 128 static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept; 129 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept; 130 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept; 131 static PetscErrorCode Sum(Vec, PetscScalar *) noexcept; 132 static PetscErrorCode Shift(Vec, PetscScalar) noexcept; 133 static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept; 134 static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept; 135 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept; 136 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept; 137 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept; 138 }; 139 140 namespace kernels 141 { 142 143 template <typename F> 144 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 145 { 146 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 147 const auto end = jmap[i + 1]; 148 const auto idx = xvindex(i); 149 PetscScalar sum = 0.0; 150 151 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 152 153 if (imode == INSERT_VALUES) { 154 xv[idx] = sum; 155 } else { 156 xv[idx] += sum; 157 } 158 }); 159 return; 160 } 161 162 namespace 163 { 164 165 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 166 { 167 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 168 return; 169 } 170 171 } // namespace 172 173 #if PetscDefined(USING_HCC) 174 namespace do_not_use 175 { 176 177 // Needed to silence clang warning: 178 // 179 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 180 // 181 // The warning is silly, since the function *is* used, however the host compiler does not 182 // appear see this. Likely because the function using it is in a template. 183 // 184 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 185 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 186 { 187 (void)add_coo_values; 188 } 189 190 } // namespace do_not_use 191 #endif 192 193 } // namespace kernels 194 195 } // namespace impl 196 197 // ========================================================================================== 198 // VecSeq_CUPM - Implementations 199 // ========================================================================================== 200 201 template <device::cupm::DeviceType T> 202 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 203 { 204 PetscFunctionBegin; 205 PetscAssertPointer(v, 4); 206 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 207 PetscFunctionReturn(PETSC_SUCCESS); 208 } 209 210 template <device::cupm::DeviceType T> 211 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 212 { 213 PetscFunctionBegin; 214 if (n && cpuarray) PetscAssertPointer(cpuarray, 4); 215 PetscAssertPointer(v, 6); 216 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 217 PetscFunctionReturn(PETSC_SUCCESS); 218 } 219 220 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 221 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 222 { 223 PetscFunctionBegin; 224 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 225 PetscAssertPointer(a, 2); 226 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 227 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 228 PetscFunctionReturn(PETSC_SUCCESS); 229 } 230 231 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 232 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 233 { 234 PetscFunctionBegin; 235 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 236 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 237 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 238 PetscFunctionReturn(PETSC_SUCCESS); 239 } 240 241 template <device::cupm::DeviceType T> 242 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 243 { 244 PetscFunctionBegin; 245 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 246 PetscFunctionReturn(PETSC_SUCCESS); 247 } 248 249 template <device::cupm::DeviceType T> 250 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 251 { 252 PetscFunctionBegin; 253 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 254 PetscFunctionReturn(PETSC_SUCCESS); 255 } 256 257 template <device::cupm::DeviceType T> 258 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 259 { 260 PetscFunctionBegin; 261 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 262 PetscFunctionReturn(PETSC_SUCCESS); 263 } 264 265 template <device::cupm::DeviceType T> 266 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 267 { 268 PetscFunctionBegin; 269 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272 273 template <device::cupm::DeviceType T> 274 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 275 { 276 PetscFunctionBegin; 277 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 278 PetscFunctionReturn(PETSC_SUCCESS); 279 } 280 281 template <device::cupm::DeviceType T> 282 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 283 { 284 PetscFunctionBegin; 285 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 286 PetscFunctionReturn(PETSC_SUCCESS); 287 } 288 289 template <device::cupm::DeviceType T> 290 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 291 { 292 PetscFunctionBegin; 293 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 294 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 295 PetscFunctionReturn(PETSC_SUCCESS); 296 } 297 298 template <device::cupm::DeviceType T> 299 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 300 { 301 PetscFunctionBegin; 302 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 303 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 304 PetscFunctionReturn(PETSC_SUCCESS); 305 } 306 307 template <device::cupm::DeviceType T> 308 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 309 { 310 PetscFunctionBegin; 311 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 312 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 313 PetscFunctionReturn(PETSC_SUCCESS); 314 } 315 316 } // namespace cupm 317 318 } // namespace vec 319 320 } // namespace Petsc 321 322 #if PetscDefined(HAVE_CUDA) 323 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>; 324 #endif 325 326 #if PetscDefined(HAVE_HIP) 327 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>; 328 #endif 329 330 #endif // PETSCVECSEQCUPM_HPP 331