xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision bcdedc735d4d558206d0dbf4329cd11c3414be55)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 #include <petsc/private/cpp/utility.hpp> // util::index_sequence
6 
7 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D()
8 #include <../src/vec/vec/impls/dvecimpl.h>                  // Vec_Seq
9 
10 namespace Petsc
11 {
12 
13 namespace vec
14 {
15 
16 namespace cupm
17 {
18 
19 namespace impl
20 {
21 
22 // ==========================================================================================
23 // VecSeq_CUPM
24 // ==========================================================================================
25 
26 template <device::cupm::DeviceType T>
27 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
28 public:
29   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
30 
31 private:
32   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
33   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
34   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
35 
36   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
37   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
38   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
39   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
40 
41   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
42 
43   // common core for min and max
44   template <typename TupleFuncT, typename UnaryFuncT>
45   static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
46   // common core for pointwise binary and pointwise unary thrust functions
47   template <typename BinaryFuncT>
48   static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
49   template <typename UnaryFuncT>
50   static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
51   // mdot dispatchers
52   static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
53   static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
54   template <std::size_t... Idx>
55   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
56   template <int>
57   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
58   template <std::size_t... Idx>
59   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
60   template <int>
61   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
62   // common core for the various create routines
63   static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
64 
65 public:
66   // callable directly via a bespoke function
67   static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
68   static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
69 
70   // callable indirectly via function pointers
71   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
72   static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
73   static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
74   static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
75   static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
76   static PetscErrorCode Reciprocal(Vec) noexcept;
77   static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
78   static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
79   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
80   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
81   static PetscErrorCode Set(Vec, PetscScalar) noexcept;
82   static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
83   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
84   static PetscErrorCode Copy(Vec, Vec) noexcept;
85   static PetscErrorCode Swap(Vec, Vec) noexcept;
86   static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
87   static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
88   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
89   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
90   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
91   static PetscErrorCode Conjugate(Vec) noexcept;
92   template <PetscMemoryAccessMode>
93   static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
94   template <PetscMemoryAccessMode>
95   static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
96   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
97   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
98   static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
99   static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
100   static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
101   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
102   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
103   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
104 };
105 
106 namespace kernels
107 {
108 
109 template <typename F>
110 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
111 {
112   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
113     const auto  end = jmap[i + 1];
114     const auto  idx = xvindex(i);
115     PetscScalar sum = 0.0;
116 
117     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
118 
119     if (imode == INSERT_VALUES) {
120       xv[idx] = sum;
121     } else {
122       xv[idx] += sum;
123     }
124   });
125   return;
126 }
127 
128 namespace
129 {
130 
131 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
132 {
133   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
134   return;
135 }
136 
137 } // namespace
138 
139 #if PetscDefined(USING_HCC)
140 namespace do_not_use
141 {
142 
143 // Needed to silence clang warning:
144 //
145 // warning: function 'FUNCTION NAME' is not needed and will not be emitted
146 //
147 // The warning is silly, since the function *is* used, however the host compiler does not
148 // appear see this. Likely because the function using it is in a template.
149 //
150 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
151 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
152 {
153   (void)add_coo_values;
154 }
155 
156 } // namespace do_not_use
157 #endif
158 
159 } // namespace kernels
160 
161 } // namespace impl
162 
163 // ==========================================================================================
164 // VecSeq_CUPM - Implementations
165 // ==========================================================================================
166 
167 template <device::cupm::DeviceType T>
168 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
169 {
170   PetscFunctionBegin;
171   PetscValidPointer(v, 4);
172   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
173   PetscFunctionReturn(PETSC_SUCCESS);
174 }
175 
176 template <device::cupm::DeviceType T>
177 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
178 {
179   PetscFunctionBegin;
180   if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4);
181   PetscValidPointer(v, 6);
182   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
183   PetscFunctionReturn(PETSC_SUCCESS);
184 }
185 
186 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
187 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
188 {
189   PetscFunctionBegin;
190   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
191   PetscValidPointer(a, 2);
192   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
193   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
194   PetscFunctionReturn(PETSC_SUCCESS);
195 }
196 
197 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
198 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
199 {
200   PetscFunctionBegin;
201   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
202   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
203   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
204   PetscFunctionReturn(PETSC_SUCCESS);
205 }
206 
207 template <device::cupm::DeviceType T>
208 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
209 {
210   PetscFunctionBegin;
211   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
212   PetscFunctionReturn(PETSC_SUCCESS);
213 }
214 
215 template <device::cupm::DeviceType T>
216 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
217 {
218   PetscFunctionBegin;
219   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
220   PetscFunctionReturn(PETSC_SUCCESS);
221 }
222 
223 template <device::cupm::DeviceType T>
224 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
225 {
226   PetscFunctionBegin;
227   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
228   PetscFunctionReturn(PETSC_SUCCESS);
229 }
230 
231 template <device::cupm::DeviceType T>
232 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
233 {
234   PetscFunctionBegin;
235   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
236   PetscFunctionReturn(PETSC_SUCCESS);
237 }
238 
239 template <device::cupm::DeviceType T>
240 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
241 {
242   PetscFunctionBegin;
243   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
244   PetscFunctionReturn(PETSC_SUCCESS);
245 }
246 
247 template <device::cupm::DeviceType T>
248 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
249 {
250   PetscFunctionBegin;
251   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
252   PetscFunctionReturn(PETSC_SUCCESS);
253 }
254 
255 template <device::cupm::DeviceType T>
256 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
257 {
258   PetscFunctionBegin;
259   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
260   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
261   PetscFunctionReturn(PETSC_SUCCESS);
262 }
263 
264 template <device::cupm::DeviceType T>
265 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
266 {
267   PetscFunctionBegin;
268   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
269   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
273 template <device::cupm::DeviceType T>
274 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
275 {
276   PetscFunctionBegin;
277   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
278   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
279   PetscFunctionReturn(PETSC_SUCCESS);
280 }
281 
282 } // namespace cupm
283 
284 } // namespace vec
285 
286 } // namespace Petsc
287 
// Explicit instantiation declarations for each enabled backend. Translation units that include
// this header do not implicitly instantiate the (non-inline) members of these specializations.
// A matching explicit instantiation definition must therefore exist in some other translation
// unit of the library.
#if PetscDefined(HAVE_CUDA)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
#endif
295 
296 #endif // PETSCVECSEQCUPM_HPP
297