1 #ifndef PETSCVECSEQCUPM_HPP 2 #define PETSCVECSEQCUPM_HPP 3 4 #include <petsc/private/veccupmimpl.h> 5 #include <petsc/private/cpp/utility.hpp> // util::index_sequence 6 7 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D() 8 #include <../src/vec/vec/impls/dvecimpl.h> // Vec_Seq 9 10 namespace Petsc 11 { 12 13 namespace vec 14 { 15 16 namespace cupm 17 { 18 19 namespace impl 20 { 21 22 // ========================================================================================== 23 // VecSeq_CUPM 24 // ========================================================================================== 25 26 template <device::cupm::DeviceType T> 27 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 28 public: 29 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 30 31 private: 32 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 33 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 34 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept; 35 36 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 37 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 38 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 39 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 40 41 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 42 43 // common core for min and max 44 template <typename TupleFuncT, typename UnaryFuncT> 45 static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 46 // common core for pointwise binary and pointwise unary thrust functions 47 template <typename BinaryFuncT> 48 static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept; 49 template <typename BinaryFuncT> 50 static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec) noexcept; 51 template <typename UnaryFuncT> 52 static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept; 53 // mdot dispatchers 54 static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 55 static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 56 template <std::size_t... Idx> 57 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 58 template <int> 59 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 60 template <std::size_t... Idx> 61 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 62 template <int> 63 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 64 // common core for the various create routines 65 static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 66 67 public: 68 // callable directly via a bespoke function 69 static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 70 static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 71 72 // callable indirectly via function pointers 73 static PetscErrorCode Duplicate(Vec, Vec *) noexcept; 74 static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept; 75 static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept; 76 static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept; 77 static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept; 78 static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept; 79 static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept; 80 static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept; 81 static PetscErrorCode Reciprocal(Vec) noexcept; 82 static PetscErrorCode Abs(Vec) noexcept; 83 static PetscErrorCode Sqrt(Vec) noexcept; 84 static PetscErrorCode Exp(Vec) noexcept; 85 static PetscErrorCode Log(Vec) noexcept; 86 static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept; 87 static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 88 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept; 89 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 90 static PetscErrorCode Set(Vec, PetscScalar) noexcept; 91 static PetscErrorCode Scale(Vec, PetscScalar) noexcept; 92 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept; 93 static PetscErrorCode Copy(Vec, Vec) noexcept; 94 static PetscErrorCode Swap(Vec, Vec) noexcept; 95 static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept; 96 static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 97 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept; 98 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept; 99 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 100 static PetscErrorCode Conjugate(Vec) noexcept; 101 template <PetscMemoryAccessMode> 102 static PetscErrorCode GetLocalVector(Vec, Vec) noexcept; 103 template <PetscMemoryAccessMode> 104 static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept; 105 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept; 106 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept; 107 static PetscErrorCode Sum(Vec, PetscScalar *) noexcept; 108 static PetscErrorCode Shift(Vec, PetscScalar) noexcept; 109 static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept; 110 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept; 111 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept; 112 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept; 113 }; 114 115 namespace kernels 116 { 117 118 template <typename F> 119 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 120 { 121 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 122 const auto end = jmap[i + 1]; 123 const auto idx = xvindex(i); 124 PetscScalar sum = 0.0; 125 126 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 127 128 if (imode == INSERT_VALUES) { 129 xv[idx] = sum; 130 } else { 131 xv[idx] += sum; 132 } 133 }); 134 return; 135 } 136 137 namespace 138 { 139 140 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 141 { 142 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 143 return; 144 } 145 146 } // namespace 147 148 #if PetscDefined(USING_HCC) 149 namespace do_not_use 150 { 151 152 // Needed to silence clang warning: 153 // 154 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 155 // 156 // The warning is silly, since the function *is* used, however the host compiler does not 157 // appear see this. Likely because the function using it is in a template. 158 // 159 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 160 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 161 { 162 (void)add_coo_values; 163 } 164 165 } // namespace do_not_use 166 #endif 167 168 } // namespace kernels 169 170 } // namespace impl 171 172 // ========================================================================================== 173 // VecSeq_CUPM - Implementations 174 // ========================================================================================== 175 176 template <device::cupm::DeviceType T> 177 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 178 { 179 PetscFunctionBegin; 180 PetscValidPointer(v, 4); 181 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 182 PetscFunctionReturn(PETSC_SUCCESS); 183 } 184 185 template <device::cupm::DeviceType T> 186 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 187 { 188 PetscFunctionBegin; 189 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4); 190 PetscValidPointer(v, 6); 191 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 192 PetscFunctionReturn(PETSC_SUCCESS); 193 } 194 195 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 196 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 197 { 198 PetscFunctionBegin; 199 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 200 PetscValidPointer(a, 2); 201 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 202 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 203 PetscFunctionReturn(PETSC_SUCCESS); 204 } 205 206 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 207 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 208 { 209 PetscFunctionBegin; 210 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 211 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 212 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 213 PetscFunctionReturn(PETSC_SUCCESS); 214 } 215 216 template <device::cupm::DeviceType T> 217 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 218 { 219 PetscFunctionBegin; 220 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 221 PetscFunctionReturn(PETSC_SUCCESS); 222 } 223 224 template <device::cupm::DeviceType T> 225 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 226 { 227 PetscFunctionBegin; 228 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 229 PetscFunctionReturn(PETSC_SUCCESS); 230 } 231 232 template <device::cupm::DeviceType T> 233 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 234 { 235 PetscFunctionBegin; 236 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 237 PetscFunctionReturn(PETSC_SUCCESS); 238 } 239 240 template <device::cupm::DeviceType T> 241 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 242 { 243 PetscFunctionBegin; 244 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 245 PetscFunctionReturn(PETSC_SUCCESS); 246 } 247 248 template <device::cupm::DeviceType T> 249 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 250 { 251 PetscFunctionBegin; 252 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 253 PetscFunctionReturn(PETSC_SUCCESS); 254 } 255 256 template <device::cupm::DeviceType T> 257 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 258 { 259 PetscFunctionBegin; 260 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 261 PetscFunctionReturn(PETSC_SUCCESS); 262 } 263 264 template <device::cupm::DeviceType T> 265 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 266 { 267 PetscFunctionBegin; 268 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 269 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272 273 template <device::cupm::DeviceType T> 274 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 275 { 276 PetscFunctionBegin; 277 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 278 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 279 PetscFunctionReturn(PETSC_SUCCESS); 280 } 281 282 template <device::cupm::DeviceType T> 283 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 284 { 285 PetscFunctionBegin; 286 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 287 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 288 PetscFunctionReturn(PETSC_SUCCESS); 289 } 290 291 } // namespace cupm 292 293 } // namespace vec 294 295 } // namespace Petsc 296 297 #if PetscDefined(HAVE_CUDA) 298 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>; 299 #endif 300 301 #if PetscDefined(HAVE_HIP) 302 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>; 303 #endif 304 305 #endif // PETSCVECSEQCUPM_HPP 306