xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision a77d671b1191ed14ce34133f22d77229bbb91e89)
1 #pragma once
2 
3 #include <petsc/private/veccupmimpl.h>
4 #include <petsc/private/cpp/utility.hpp> // util::index_sequence
5 
6 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D()
7 #include <../src/vec/vec/impls/dvecimpl.h>                  // Vec_Seq
8 
9 namespace Petsc
10 {
11 
12 namespace vec
13 {
14 
15 namespace cupm
16 {
17 
18 namespace impl
19 {
20 
21 // ==========================================================================================
22 // VecSeq_CUPM
23 // ==========================================================================================
24 
// VecSeq_CUPM - sequential Vec implementation for a CUPM device backend.
//
// T selects the backend (device::cupm::DeviceType); explicit instantiation declarations for CUDA
// and HIP appear at the end of this header. The class hands itself to Vec_CUPMBase as the second
// template argument (CRTP-style) and, being declared with `class` and no access specifier on the
// base, inherits from it privately. Every member declared here is static, and none of them is
// defined in this header.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  // NOTE(review): macro from veccupmimpl.h; presumably declares the `base_type` alias and makes
  // shared Vec_CUPMBase helpers (e.g. GetArray/RestoreArray used below) accessible -- confirm.
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // Trailing-underscore hooks appear to be customization points consumed by Vec_CUPMBase
  // (TODO confirm): access to the host-side Vec_Seq data and the VecType names of the device and
  // host implementations.
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  // (the PetscDeviceContext argument defaults to nullptr)
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
  // NOTE(review): the function-pointer argument has the plain (Vec, Vec, Vec) pointwise signature;
  // presumably an alternative implementation the dispatcher may select instead of the functor.
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
  // mdot dispatchers, tag-dispatched on std::true_type / std::false_type ("use complex")
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // MDot/MAXPY kernel-launch helpers come in two overload families: one parameterized by an
  // index_sequence (presumably a single launch covering sizeof...(Idx) vectors), one by an `int`
  // template argument taking a PetscInt & (presumably a running offset into the Vec array that
  // the callee advances). TODO confirm against the out-of-line definitions.
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines; host_ptr and device_ptr default to nullptr
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // NOTE(review): by name, these install / remove the *Async entry points on a Vec -- confirm.
  static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
  static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;

  // callable indirectly via function pointers
  // Each *Async variant takes the same arguments as its plain counterpart plus a trailing
  // PetscDeviceContext.
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Abs(Vec) noexcept;
  static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode SqrtAbs(Vec) noexcept;
  static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Exp(Vec) noexcept;
  static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Log(Vec) noexcept;
  static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
  // local-vector access, templated on the requested memory access mode
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};
138 
139 namespace kernels
140 {
141 
142 template <typename F>
143 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144 {
145   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146     const auto  end = jmap[i + 1];
147     const auto  idx = xvindex(i);
148     PetscScalar sum = 0.0;
149 
150     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
151 
152     if (imode == INSERT_VALUES) {
153       xv[idx] = sum;
154     } else {
155       xv[idx] += sum;
156     }
157   });
158   return;
159 }
160 
161 namespace
162 {
163 
164 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165 {
166   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167   return;
168 }
169 
170 } // namespace
171 
#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence a clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is spurious: the function *is* used, but the host compiler does not appear to see
// this, likely because the function using it is in a template.
//
// This warning appeared in clang-11 and still persists as of clang-15 (21/02/2023).
//
// The non-template function below names add_coo_values in a discarded expression so that the
// compiler sees a reference to the kernel. It exists only for that purpose; do not call it.
inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
#endif
191 
192 } // namespace kernels
193 
194 } // namespace impl
195 
196 // ==========================================================================================
// VecSeq_CUPM - Free-function wrappers around VecSeq_CUPM<T>
198 // ==========================================================================================
199 
200 template <device::cupm::DeviceType T>
201 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
202 {
203   PetscFunctionBegin;
204   PetscAssertPointer(v, 4);
205   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
206   PetscFunctionReturn(PETSC_SUCCESS);
207 }
208 
209 template <device::cupm::DeviceType T>
210 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
211 {
212   PetscFunctionBegin;
213   if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
214   PetscAssertPointer(v, 6);
215   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
220 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
221 {
222   PetscFunctionBegin;
223   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
224   PetscAssertPointer(a, 2);
225   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
226   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
227   PetscFunctionReturn(PETSC_SUCCESS);
228 }
229 
230 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
231 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
232 {
233   PetscFunctionBegin;
234   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
235   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
236   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
237   PetscFunctionReturn(PETSC_SUCCESS);
238 }
239 
240 template <device::cupm::DeviceType T>
241 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
242 {
243   PetscFunctionBegin;
244   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
245   PetscFunctionReturn(PETSC_SUCCESS);
246 }
247 
248 template <device::cupm::DeviceType T>
249 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
250 {
251   PetscFunctionBegin;
252   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
253   PetscFunctionReturn(PETSC_SUCCESS);
254 }
255 
256 template <device::cupm::DeviceType T>
257 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
258 {
259   PetscFunctionBegin;
260   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
261   PetscFunctionReturn(PETSC_SUCCESS);
262 }
263 
264 template <device::cupm::DeviceType T>
265 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
266 {
267   PetscFunctionBegin;
268   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
269   PetscFunctionReturn(PETSC_SUCCESS);
270 }
271 
272 template <device::cupm::DeviceType T>
273 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
274 {
275   PetscFunctionBegin;
276   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
277   PetscFunctionReturn(PETSC_SUCCESS);
278 }
279 
280 template <device::cupm::DeviceType T>
281 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
282 {
283   PetscFunctionBegin;
284   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
285   PetscFunctionReturn(PETSC_SUCCESS);
286 }
287 
288 template <device::cupm::DeviceType T>
289 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
290 {
291   PetscFunctionBegin;
292   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
293   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
294   PetscFunctionReturn(PETSC_SUCCESS);
295 }
296 
297 template <device::cupm::DeviceType T>
298 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
299 {
300   PetscFunctionBegin;
301   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
302   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 template <device::cupm::DeviceType T>
307 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
308 {
309   PetscFunctionBegin;
310   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
311   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
312   PetscFunctionReturn(PETSC_SUCCESS);
313 }
314 
315 } // namespace cupm
316 
317 } // namespace vec
318 
319 } // namespace Petsc
320 
// Explicit instantiation declarations for each enabled backend. Translation units that include
// this header do not implicitly instantiate the non-inline members of VecSeq_CUPM<CUDA/HIP>.
// A matching explicit instantiation definition must exist in some other translation unit
// (presumably the backend-specific source that builds this implementation -- confirm).
#if PetscDefined(HAVE_CUDA)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
#endif
328