xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision e056e8ce1cfdaf3750afb193bfd9d5df747ade21)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 #include <petsc/private/cpp/utility.hpp> // util::index_sequence
6 
7 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D()
8 #include <../src/vec/vec/impls/dvecimpl.h>                  // Vec_Seq
9 
10 namespace Petsc
11 {
12 
13 namespace vec
14 {
15 
16 namespace cupm
17 {
18 
19 namespace impl
20 {
21 
22 // ==========================================================================================
23 // VecSeq_CUPM
24 // ==========================================================================================
25 
// Implementation class for the VecSeq CUPM vector type, parameterised on the device backend T.
//
// It derives from Vec_CUPMBase<T, VecSeq_CUPM<T>>, i.e. it hands itself to the base as a
// template argument (CRTP). No access specifier is given on the base, so for a `class` the
// inheritance is private. Every member function declared explicitly below is static and
// noexcept; only declarations appear here, no definitions. Explicit instantiations for the
// CUDA and HIP device types are declared `extern template` at the end of this header.
26 template <device::cupm::DeviceType T>
27 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
28 public:
  // NOTE(review): macro defined outside this file (presumably veccupmimpl.h); its expansion is
  // not visible here -- confirm what it exports before relying on any base-class name.
29   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
30 
31 private:
  // implementation-specific hooks; NOTE(review): presumably consumed by the Vec_CUPMBase CRTP
  // base -- confirm against its definition
32   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
33   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
34   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
35 
36   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
37   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
38   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
39   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
40 
41   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
42 
43   // common core for min and max
44   template <typename TupleFuncT, typename UnaryFuncT>
45   static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
46   // common core for pointwise binary and pointwise unary thrust functions
47   template <typename BinaryFuncT>
48   static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  // same as PointwiseBinary_ but additionally takes a plain (Vec, Vec, Vec) function pointer
49   template <typename BinaryFuncT>
50   static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  // the trailing Vec (labelled "out") is optional and defaults to nullptr
51   template <typename UnaryFuncT>
52   static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
53   // mdot dispatchers
  // the two MDot_ overloads differ only in the leading std::true_type / std::false_type tag
54   static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55   static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // each *_kernel_dispatch_ name has two overloads: one templated on an index pack and taking
  // a util::index_sequence<Idx...>, one templated on a single int and taking a PetscInt &
56   template <std::size_t... Idx>
57   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
58   template <int>
59   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
60   template <std::size_t... Idx>
61   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
62   template <int>
63   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
64   // common core for the various create routines
  // both array pointers (labelled host_ptr and device_ptr) are optional and default to nullptr
65   static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
66 
67 public:
68   // callable directly via a bespoke function
  // (in this header: VecCreateSeqCUPMAsync() and VecCreateSeqCUPMWithArraysAsync())
69   static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
70   static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
71 
72   // callable indirectly via function pointers
73   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
74   static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
75   static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
76   static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
77   static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
78   static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
79   static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
80   static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
81   static PetscErrorCode Reciprocal(Vec) noexcept;
82   static PetscErrorCode Abs(Vec) noexcept;
83   static PetscErrorCode Sqrt(Vec) noexcept;
84   static PetscErrorCode Exp(Vec) noexcept;
85   static PetscErrorCode Log(Vec) noexcept;
86   static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
87   static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
88   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
89   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
90   static PetscErrorCode Set(Vec, PetscScalar) noexcept;
91   static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
92   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
93   static PetscErrorCode Copy(Vec, Vec) noexcept;
94   static PetscErrorCode Swap(Vec, Vec) noexcept;
95   static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
96   static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
97   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
98   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
99   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
100   static PetscErrorCode Conjugate(Vec) noexcept;
  // the memory access mode is a compile-time (template) parameter for the local-vector pair
101   template <PetscMemoryAccessMode>
102   static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
103   template <PetscMemoryAccessMode>
104   static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
105   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
106   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
107   static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
108   static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
109   static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
110   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
111   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
112   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
113 };
114 
115 namespace kernels
116 {
117 
118 template <typename F>
119 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
120 {
121   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
122     const auto  end = jmap[i + 1];
123     const auto  idx = xvindex(i);
124     PetscScalar sum = 0.0;
125 
126     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
127 
128     if (imode == INSERT_VALUES) {
129       xv[idx] = sum;
130     } else {
131       xv[idx] += sum;
132     }
133   });
134   return;
135 }
136 
137 namespace
138 {
139 
140 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
141 {
142   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
143   return;
144 }
145 
146 } // namespace
147 
148 #if PetscDefined(USING_HCC)
149 namespace do_not_use
150 {
151 
152 // Needed to silence the following clang warning:
153 //
154 // warning: function 'FUNCTION NAME' is not needed and will not be emitted
155 //
156 // The warning is spurious, since the function *is* used; however, the host compiler does not
157 // appear to see this, likely because the function using it is in a template.
158 //
159 // This warning first appeared in clang-11 and still persisted as of clang-15 (21/02/2023).
160 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
161 {
162   (void)add_coo_values; // names the kernel in a discarded expression; nothing is launched
163 }
164 
165 } // namespace do_not_use
166 #endif
167 
168 } // namespace kernels
169 
170 } // namespace impl
171 
172 // ==========================================================================================
173 // VecSeq_CUPM - Implementations
174 // ==========================================================================================
175 
176 template <device::cupm::DeviceType T>
177 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
178 {
179   PetscFunctionBegin;
180   PetscValidPointer(v, 4);
181   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
182   PetscFunctionReturn(PETSC_SUCCESS);
183 }
184 
185 template <device::cupm::DeviceType T>
186 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
187 {
188   PetscFunctionBegin;
189   if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4);
190   PetscValidPointer(v, 6);
191   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
192   PetscFunctionReturn(PETSC_SUCCESS);
193 }
194 
195 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
196 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
197 {
198   PetscFunctionBegin;
199   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
200   PetscValidPointer(a, 2);
201   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
202   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
203   PetscFunctionReturn(PETSC_SUCCESS);
204 }
205 
206 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
207 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
208 {
209   PetscFunctionBegin;
210   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
211   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
212   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
213   PetscFunctionReturn(PETSC_SUCCESS);
214 }
215 
216 template <device::cupm::DeviceType T>
217 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
218 {
219   PetscFunctionBegin;
220   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
221   PetscFunctionReturn(PETSC_SUCCESS);
222 }
223 
224 template <device::cupm::DeviceType T>
225 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
226 {
227   PetscFunctionBegin;
228   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
229   PetscFunctionReturn(PETSC_SUCCESS);
230 }
231 
232 template <device::cupm::DeviceType T>
233 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
234 {
235   PetscFunctionBegin;
236   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
237   PetscFunctionReturn(PETSC_SUCCESS);
238 }
239 
240 template <device::cupm::DeviceType T>
241 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
242 {
243   PetscFunctionBegin;
244   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
245   PetscFunctionReturn(PETSC_SUCCESS);
246 }
247 
248 template <device::cupm::DeviceType T>
249 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
250 {
251   PetscFunctionBegin;
252   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
253   PetscFunctionReturn(PETSC_SUCCESS);
254 }
255 
256 template <device::cupm::DeviceType T>
257 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
258 {
259   PetscFunctionBegin;
260   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
261   PetscFunctionReturn(PETSC_SUCCESS);
262 }
263 
264 template <device::cupm::DeviceType T>
265 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
266 {
267   PetscFunctionBegin;
268   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
269   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
273 template <device::cupm::DeviceType T>
274 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
275 {
276   PetscFunctionBegin;
277   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
278   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
279   PetscFunctionReturn(PETSC_SUCCESS);
280 }
281 
282 template <device::cupm::DeviceType T>
283 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
284 {
285   PetscFunctionBegin;
286   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
287   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
288   PetscFunctionReturn(PETSC_SUCCESS);
289 }
290 
291 } // namespace cupm
292 
293 } // namespace vec
294 
295 } // namespace Petsc
296 
297 #if PetscDefined(HAVE_CUDA)
298 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
299 #endif
300 
301 #if PetscDefined(HAVE_HIP)
302 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
303 #endif
304 
305 #endif // PETSCVECSEQCUPM_HPP
306