xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 7a2fbddbbf7441c3c5c59a75eb0adb5a0d08726e)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 #include <petsc/private/cpp/utility.hpp> // util::index_sequence
6 
7 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D()
8 #include <../src/vec/vec/impls/dvecimpl.h>                  // Vec_Seq
9 
10 namespace Petsc
11 {
12 
13 namespace vec
14 {
15 
16 namespace cupm
17 {
18 
19 namespace impl
20 {
21 
22 // ==========================================================================================
23 // VecSeq_CUPM
24 // ==========================================================================================
25 
// Class template gathering the VecSeq_CUPM vector operations for device type T. It inherits
// (privately, as no access specifier is given) from Vec_CUPMBase, handing itself to the base
// as the second template argument. Every member declared below is a static function and only
// declarations appear in this class body. The extern template declarations at the end of this
// header name the CUDA and HIP instantiations.
26 template <device::cupm::DeviceType T>
27 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
28 public:
     // NOTE(review): macro defined outside this header; presumably it introduces the
     // base_type alias and re-exports base-class names used below (e.g. cupmStream_t) - confirm.
29   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
30 
31 private:
     // Implementation-type accessors and type-name constants.
     // NOTE(review): presumably these are the hooks the base class looks up on the derived
     // class - confirm against Vec_CUPMBase.
32   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
33   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
34   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
35 
     // Implementation-specific destroy/reset/place/create helpers.
36   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
37   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
38   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
39   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
40 
41   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
42 
43   // common core for min and max
44   template <typename TupleFuncT, typename UnaryFuncT>
45   static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
46   // common core for pointwise binary and pointwise unary thrust functions
     // (the trailing PetscDeviceContext parameter defaults to nullptr in all three helpers)
47   template <typename BinaryFuncT>
48   static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
49   template <typename BinaryFuncT>
50   static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
51   template <typename UnaryFuncT>
52   static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
53   // mdot dispatchers
     // The two MDot_ overloads are selected by a std::true_type / std::false_type tag argument
     // ("use complex"). The *_kernel_dispatch_ helpers each come in two forms: one templated on
     // an index pack and taking util::index_sequence, one templated on an int and taking a
     // PetscInt reference.
54   static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55   static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
56   template <std::size_t... Idx>
57   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
58   template <int>
59   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
60   template <std::size_t... Idx>
61   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
62   template <int>
63   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
64   // common core for the various create routines
     // (both the host and the device pointer default to nullptr)
65   static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
66 
67 public:
68   // callable directly via a bespoke function
     // (the free functions VecCreateSeqCUPMAsync() and VecCreateSeqCUPMWithArraysAsync() later
     // in this header call these two)
69   static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
70   static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
71 
72   static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
73   static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;
74 
75   // callable indirectly via function pointers
     // Many operations below are declared in pairs: Foo(...) and FooAsync(..., PetscDeviceContext).
     // NOTE(review): presumably the plain form forwards to the Async form with a default
     // context - confirm in the implementation file.
76   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
77   static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
78   static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
79   static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
80   static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
81   static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
82   static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
83   static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
84   static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
85   static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
86   static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
87   static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
88   static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
89   static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
90   static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
91   static PetscErrorCode Reciprocal(Vec) noexcept;
92   static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
93   static PetscErrorCode Abs(Vec) noexcept;
94   static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
95   static PetscErrorCode SqrtAbs(Vec) noexcept;
96   static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
97   static PetscErrorCode Exp(Vec) noexcept;
98   static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
99   static PetscErrorCode Log(Vec) noexcept;
100   static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
101   static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
102   static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
103   static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
104   static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
105   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
106   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
107   static PetscErrorCode Set(Vec, PetscScalar) noexcept;
108   static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
109   static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
110   static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
111   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
112   static PetscErrorCode Copy(Vec, Vec) noexcept;
113   static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
114   static PetscErrorCode Swap(Vec, Vec) noexcept;
115   static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
116   static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
117   static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
118   static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
119   static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
120   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
121   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
122   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
123   static PetscErrorCode Conjugate(Vec) noexcept;
124   static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
     // Get/Restore of a local vector, with the memory access mode fixed at compile time.
125   template <PetscMemoryAccessMode>
126   static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
127   template <PetscMemoryAccessMode>
128   static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
129   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
130   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
131   static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
132   static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
133   static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
134   static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
135   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
136   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
137   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
138 };
139 
140 namespace kernels
141 {
142 
// Shared body for COO value insertion. For every index i handed to the callback it sums
// vv[perm[k]] over k in [jmap[i], jmap[i + 1]) and stores the result at xv[xvindex(i)].
//
//   vv      - input values, read through the permutation perm
//   n       - count passed to grid_stride_1D()
//   jmap    - offsets into perm; entries i and i + 1 bound the range summed for output i
//             (NOTE(review): presumably needs n + 1 entries - confirm against the caller)
//   perm    - indices into vv
//   imode   - INSERT_VALUES overwrites the destination, any other value adds to it
//   xv      - output array
//   xvindex - callable mapping i to the destination index in xv
//
// NOTE(review): presumably grid_stride_1D() invokes the callback once for each i in [0, n),
// distributed over the launched threads - confirm in kernels.hpp.
143 template <typename F>
144 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
145 {
      // the lambda captures by copy ([=]), so the pointers, imode and xvindex are copied into it
146   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
147     const auto  end = jmap[i + 1];
148     const auto  idx = xvindex(i);
149     PetscScalar sum = 0.0;
150 
        // accumulate every input value that maps onto output entry i
151     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
152 
153     if (imode == INSERT_VALUES) {
154       xv[idx] = sum;
155     } else {
156       xv[idx] += sum;
157     }
158   });
159   return;
160 }
161 
162 namespace
163 {
164 
// Kernel entry point for COO value insertion: forwards every argument to
// add_coo_values_impl() with an identity index map, so the sum computed for index i is
// written to (or added into) xv[i].
165 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
166 {
167   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
168   return;
169 }
170 
171 } // namespace
172 
173 #if PetscDefined(USING_HCC)
174 namespace do_not_use
175 {
176 
177 // Needed to silence clang warning:
178 //
179 // warning: function 'FUNCTION NAME' is not needed and will not be emitted
180 //
181 // The warning is silly, since the function *is* used; however, the host compiler does not
182 // appear to see this, likely because the function using it is in a template.
183 //
184 // This warning appeared in clang-11 and still persists as of clang-15 (21/02/2023)
185 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
186 {
      // name the kernel in a discarded expression so that it counts as referenced
187   (void)add_coo_values;
188 }
189 
190 } // namespace do_not_use
191 #endif
192 
193 } // namespace kernels
194 
195 } // namespace impl
196 
197 // ==========================================================================================
198 // VecSeq_CUPM - Implementations
199 // ==========================================================================================
200 
// Creates a VecSeq_CUPM vector for device type T.
//
//   comm - communicator forwarded to VecSeq_CUPM<T>::CreateSeqCUPM()
//   n    - forwarded as the third argument of CreateSeqCUPM(); the second argument is fixed to 0
//   v    - output location for the new vector, must not be null
//
// Returns PETSC_SUCCESS, or the error propagated by PetscCall().
//
// The index given to PetscAssertPointer() is the position of the checked argument in *this*
// function's parameter list, since that is what the resulting error message reports. v is
// the 3rd parameter here; it is the 4th only in the inner CreateSeqCUPM() call.
201 template <device::cupm::DeviceType T>
202 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
203 {
204   PetscFunctionBegin;
205   PetscAssertPointer(v, 3);
206   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
207   PetscFunctionReturn(PETSC_SUCCESS);
208 }
209 
// Creates a VecSeq_CUPM vector for device type T from caller-supplied arrays by forwarding to
// VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays().
//
//   comm     - communicator, forwarded unchanged
//   bs, n    - forwarded unchanged as the second and third arguments
//   cpuarray - forwarded unchanged; only pointer-checked when n is nonzero and it is non-null
//   gpuarray - forwarded unchanged, not checked here
//   v        - output location for the new vector, must not be null
//
// Returns PETSC_SUCCESS, or the error propagated by PetscCall().
210 template <device::cupm::DeviceType T>
211 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
212 {
213   PetscFunctionBegin;
214   if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
215   PetscAssertPointer(v, 6);
216   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
217   PetscFunctionReturn(PETSC_SUCCESS);
218 }
219 
// Common implementation behind the VecCUPMGetArray{,Read,Write}Async() wrappers below.
// Validates v and a, then calls VecSeq_CUPM<T>::GetArray with PETSC_MEMTYPE_DEVICE and the
// compile-time access mode.
//
//   mode - memory access mode, chosen by the calling wrapper
//   v    - vector to access, checked against VEC_CLASSID
//   a    - output location for the array pointer, must not be null
//   dctx - device context; passed by address through
//          PetscDeviceContextGetOptionalNullContext_Internal() before use
//          (NOTE(review): presumably this substitutes a default context for nullptr - confirm)
220 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
221 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
222 {
223   PetscFunctionBegin;
224   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
225   PetscAssertPointer(a, 2);
226   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
227   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
228   PetscFunctionReturn(PETSC_SUCCESS);
229 }
230 
// Common implementation behind the VecCUPMRestoreArray{,Read,Write}Async() wrappers below.
// Validates v, then calls VecSeq_CUPM<T>::RestoreArray with PETSC_MEMTYPE_DEVICE and the
// compile-time access mode.
//
//   mode - memory access mode, chosen by the calling wrapper
//   v    - vector whose array is being restored, checked against VEC_CLASSID
//   a    - array pointer location; unlike the Get counterpart it is not pointer-checked here
//   dctx - device context; passed by address through
//          PetscDeviceContextGetOptionalNullContext_Internal() before use
//          (NOTE(review): presumably this substitutes a default context for nullptr - confirm)
231 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
232 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
233 {
234   PetscFunctionBegin;
235   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
236   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
237   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
238   PetscFunctionReturn(PETSC_SUCCESS);
239 }
240 
// Read-write array access: forwards to VecCUPMGetArrayAsync_Private() with
// PETSC_MEMORY_ACCESS_READ_WRITE. dctx defaults to nullptr.
241 template <device::cupm::DeviceType T>
242 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
243 {
244   PetscFunctionBegin;
245   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
246   PetscFunctionReturn(PETSC_SUCCESS);
247 }
248 
// Counterpart of VecCUPMGetArrayAsync(): forwards to VecCUPMRestoreArrayAsync_Private() with
// PETSC_MEMORY_ACCESS_READ_WRITE. dctx defaults to nullptr.
249 template <device::cupm::DeviceType T>
250 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
251 {
252   PetscFunctionBegin;
253   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
254   PetscFunctionReturn(PETSC_SUCCESS);
255 }
256 
// Read-only array access: forwards to VecCUPMGetArrayAsync_Private() with
// PETSC_MEMORY_ACCESS_READ. dctx defaults to nullptr.
257 template <device::cupm::DeviceType T>
258 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
259 {
260   PetscFunctionBegin;
      // the shared helper takes PetscScalar **, so the const qualifier on the pointee is cast away
261   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
262   PetscFunctionReturn(PETSC_SUCCESS);
263 }
264 
// Counterpart of VecCUPMGetArrayReadAsync(): forwards to VecCUPMRestoreArrayAsync_Private()
// with PETSC_MEMORY_ACCESS_READ. dctx defaults to nullptr.
265 template <device::cupm::DeviceType T>
266 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
267 {
268   PetscFunctionBegin;
      // the shared helper takes PetscScalar **, so the const qualifier on the pointee is cast away
269   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
// Write-only array access: forwards to VecCUPMGetArrayAsync_Private() with
// PETSC_MEMORY_ACCESS_WRITE. dctx defaults to nullptr.
273 template <device::cupm::DeviceType T>
274 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
275 {
276   PetscFunctionBegin;
277   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
278   PetscFunctionReturn(PETSC_SUCCESS);
279 }
280 
// Counterpart of VecCUPMGetArrayWriteAsync(): forwards to VecCUPMRestoreArrayAsync_Private()
// with PETSC_MEMORY_ACCESS_WRITE. dctx defaults to nullptr.
281 template <device::cupm::DeviceType T>
282 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
283 {
284   PetscFunctionBegin;
285   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
286   PetscFunctionReturn(PETSC_SUCCESS);
287 }
288 
// Validates vin against VEC_CLASSID, then forwards vin and a to
// VecSeq_CUPM<T>::PlaceArray<PETSC_MEMTYPE_DEVICE>(). The array a is not checked here.
// Returns PETSC_SUCCESS, or the error propagated by PetscCall().
289 template <device::cupm::DeviceType T>
290 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
291 {
292   PetscFunctionBegin;
293   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
294   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
295   PetscFunctionReturn(PETSC_SUCCESS);
296 }
297 
// Validates vin against VEC_CLASSID, then forwards vin and a to
// VecSeq_CUPM<T>::ReplaceArray<PETSC_MEMTYPE_DEVICE>(). The array a is not checked here.
// Returns PETSC_SUCCESS, or the error propagated by PetscCall().
298 template <device::cupm::DeviceType T>
299 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
300 {
301   PetscFunctionBegin;
302   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
303   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
304   PetscFunctionReturn(PETSC_SUCCESS);
305 }
306 
// Validates vin against VEC_CLASSID, then forwards it to
// VecSeq_CUPM<T>::ResetArray<PETSC_MEMTYPE_DEVICE>().
// Returns PETSC_SUCCESS, or the error propagated by PetscCall().
307 template <device::cupm::DeviceType T>
308 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
309 {
310   PetscFunctionBegin;
311   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
312   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
313   PetscFunctionReturn(PETSC_SUCCESS);
314 }
315 
316 } // namespace cupm
317 
318 } // namespace vec
319 
320 } // namespace Petsc
321 
// Explicit instantiation declarations of VecSeq_CUPM, one per enabled backend. They stop
// translation units that include this header from implicitly instantiating the class members;
// the matching explicit instantiation definitions have to be provided in some other
// translation unit.
322 #if PetscDefined(HAVE_CUDA)
323 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
324 #endif
325 
326 #if PetscDefined(HAVE_HIP)
327 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
328 #endif
329 
330 #endif // PETSCVECSEQCUPM_HPP
331