1 #ifndef PETSCVECSEQCUPM_HPP 2 #define PETSCVECSEQCUPM_HPP 3 4 #include <petsc/private/veccupmimpl.h> 5 #include <petsc/private/cpp/utility.hpp> // util::index_sequence 6 7 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D() 8 #include <../src/vec/vec/impls/dvecimpl.h> // Vec_Seq 9 10 namespace Petsc 11 { 12 13 namespace vec 14 { 15 16 namespace cupm 17 { 18 19 namespace impl 20 { 21 22 // ========================================================================================== 23 // VecSeq_CUPM 24 // ========================================================================================== 25 26 template <device::cupm::DeviceType T> 27 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> { 28 public: 29 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>); 30 31 private: 32 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept; 33 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept; 34 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept; 35 36 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept; 37 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept; 38 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept; 39 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept; 40 41 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept; 42 43 // common core for min and max 44 template <typename TupleFuncT, typename UnaryFuncT> 45 static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept; 46 // common core for pointwise binary and pointwise unary thrust functions 47 template <typename BinaryFuncT> 48 static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept; 49 template <typename UnaryFuncT> 50 static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept; 51 // mdot dispatchers 52 static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 53 static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept; 54 template <std::size_t... Idx> 55 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept; 56 template <int> 57 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept; 58 template <std::size_t... Idx> 59 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept; 60 template <int> 61 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept; 62 // common core for the various create routines 63 static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept; 64 65 public: 66 // callable directly via a bespoke function 67 static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept; 68 static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept; 69 70 // callable indirectly via function pointers 71 static PetscErrorCode Duplicate(Vec, Vec *) noexcept; 72 static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept; 73 static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept; 74 static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept; 75 static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept; 76 static PetscErrorCode Reciprocal(Vec) noexcept; 77 static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept; 78 static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept; 79 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept; 80 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept; 81 static PetscErrorCode Set(Vec, PetscScalar) noexcept; 82 static PetscErrorCode Scale(Vec, PetscScalar) noexcept; 83 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept; 84 static PetscErrorCode Copy(Vec, Vec) noexcept; 85 static PetscErrorCode Swap(Vec, Vec) noexcept; 86 static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept; 87 static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept; 88 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept; 89 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept; 90 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept; 91 static PetscErrorCode Conjugate(Vec) noexcept; 92 template <PetscMemoryAccessMode> 93 static PetscErrorCode GetLocalVector(Vec, Vec) noexcept; 94 template <PetscMemoryAccessMode> 95 static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept; 96 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept; 97 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept; 98 static PetscErrorCode Sum(Vec, PetscScalar *) noexcept; 99 static PetscErrorCode Shift(Vec, PetscScalar) noexcept; 100 static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept; 101 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept; 102 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept; 103 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept; 104 }; 105 106 namespace kernels 107 { 108 109 template <typename F> 110 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex) 111 { 112 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) { 113 const auto end = jmap[i + 1]; 114 const auto idx = xvindex(i); 115 PetscScalar sum = 0.0; 116 117 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]]; 118 119 if (imode == INSERT_VALUES) { 120 xv[idx] = sum; 121 } else { 122 xv[idx] += sum; 123 } 124 }); 125 return; 126 } 127 128 namespace 129 { 130 131 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv) 132 { 133 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; }); 134 return; 135 } 136 137 } // namespace 138 139 #if PetscDefined(USING_HCC) 140 namespace do_not_use 141 { 142 143 // Needed to silence clang warning: 144 // 145 // warning: function 'FUNCTION NAME' is not needed and will not be emitted 146 // 147 // The warning is silly, since the function *is* used, however the host compiler does not 148 // appear see this. Likely because the function using it is in a template. 149 // 150 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023) 151 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted() 152 { 153 (void)add_coo_values; 154 } 155 156 } // namespace do_not_use 157 #endif 158 159 } // namespace kernels 160 161 } // namespace impl 162 163 // ========================================================================================== 164 // VecSeq_CUPM - Implementations 165 // ========================================================================================== 166 167 template <device::cupm::DeviceType T> 168 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept 169 { 170 PetscFunctionBegin; 171 PetscValidPointer(v, 4); 172 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE)); 173 PetscFunctionReturn(PETSC_SUCCESS); 174 } 175 176 template <device::cupm::DeviceType T> 177 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept 178 { 179 PetscFunctionBegin; 180 if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4); 181 PetscValidPointer(v, 6); 182 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v)); 183 PetscFunctionReturn(PETSC_SUCCESS); 184 } 185 186 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 187 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 188 { 189 PetscFunctionBegin; 190 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 191 PetscValidPointer(a, 2); 192 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 193 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 194 PetscFunctionReturn(PETSC_SUCCESS); 195 } 196 197 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T> 198 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept 199 { 200 PetscFunctionBegin; 201 PetscValidHeaderSpecific(v, VEC_CLASSID, 1); 202 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 203 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx)); 204 PetscFunctionReturn(PETSC_SUCCESS); 205 } 206 207 template <device::cupm::DeviceType T> 208 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 209 { 210 PetscFunctionBegin; 211 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 212 PetscFunctionReturn(PETSC_SUCCESS); 213 } 214 215 template <device::cupm::DeviceType T> 216 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 217 { 218 PetscFunctionBegin; 219 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx)); 220 PetscFunctionReturn(PETSC_SUCCESS); 221 } 222 223 template <device::cupm::DeviceType T> 224 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 225 { 226 PetscFunctionBegin; 227 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 228 PetscFunctionReturn(PETSC_SUCCESS); 229 } 230 231 template <device::cupm::DeviceType T> 232 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 233 { 234 PetscFunctionBegin; 235 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); 236 PetscFunctionReturn(PETSC_SUCCESS); 237 } 238 239 template <device::cupm::DeviceType T> 240 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 241 { 242 PetscFunctionBegin; 243 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 244 PetscFunctionReturn(PETSC_SUCCESS); 245 } 246 247 template <device::cupm::DeviceType T> 248 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept 249 { 250 PetscFunctionBegin; 251 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx)); 252 PetscFunctionReturn(PETSC_SUCCESS); 253 } 254 255 template <device::cupm::DeviceType T> 256 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 257 { 258 PetscFunctionBegin; 259 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 260 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 261 PetscFunctionReturn(PETSC_SUCCESS); 262 } 263 264 template <device::cupm::DeviceType T> 265 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept 266 { 267 PetscFunctionBegin; 268 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 269 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a)); 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272 273 template <device::cupm::DeviceType T> 274 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept 275 { 276 PetscFunctionBegin; 277 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1); 278 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin)); 279 PetscFunctionReturn(PETSC_SUCCESS); 280 } 281 282 } // namespace cupm 283 284 } // namespace vec 285 286 } // namespace Petsc 287 288 #if PetscDefined(HAVE_CUDA) 289 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>; 290 #endif 291 292 #if PetscDefined(HAVE_HIP) 293 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>; 294 #endif 295 296 #endif // PETSCVECSEQCUPM_HPP 297