xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 8ba4effdc256b99103d44af35f27014676a8e511)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 
6 #if defined(__cplusplus)
7   #include <petsc/private/randomimpl.h> // for _p_PetscRandom
8 
9   #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence
10 
11   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
12   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
13 
14   #if PetscDefined(USE_COMPLEX)
15     #include <thrust/transform_reduce.h>
16   #endif
17   #include <thrust/transform.h>
18   #include <thrust/reduce.h>
19   #include <thrust/functional.h>
20   #include <thrust/tuple.h>
21   #include <thrust/device_ptr.h>
22   #include <thrust/iterator/zip_iterator.h>
23   #include <thrust/iterator/counting_iterator.h>
24   #include <thrust/inner_product.h>
25 
26 namespace Petsc
27 {
28 
29 namespace vec
30 {
31 
32 namespace cupm
33 {
34 
35 namespace impl
36 {
37 
38 // ==========================================================================================
39 // VecSeq_CUPM
40 // ==========================================================================================
41 
// Sequential (single-rank) CUPM (CUDA/HIP) vector implementation. Privately inherits the
// CRTP base Vec_CUPMBase<T, VecSeq_CUPM<T>>, which supplies the shared host/device array
// management machinery; this class supplies the VECSEQ-specific hooks and compute methods.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // cast Vec::data to the host-side Vec_Seq representation
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  // device vector type name (VECSEQCUPM)
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  // host counterpart type name (VECSEQ)
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  // thin dispatchers to the corresponding VecXXX_Seq host routines
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state of a locally-empty (but globally nonempty) vector so collective
  // state checks stay consistent across ranks (see comment above its definition)
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};
120 
121 // ==========================================================================================
122 // VecSeq_CUPM - Private API
123 // ==========================================================================================
124 
125 template <device::cupm::DeviceType T>
126 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
127 {
128   return static_cast<Vec_Seq *>(v->data);
129 }
130 
// Name of the device-enabled sequential vector type implemented by this class
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
136 
// Name of the plain host vector type this class mirrors
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}
142 
// Dispatch destruction of the host-side implementation to the VECSEQ routine
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}
148 
// Dispatch host-array reset to the VECSEQ routine
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}
154 
// Dispatch host-array placement to the VECSEQ routine; a does not change ownership
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}
160 
161 template <device::cupm::DeviceType T>
162 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
163 {
164   PetscMPIInt size;
165 
166   PetscFunctionBegin;
167   if (alloc_missing) *alloc_missing = PETSC_FALSE;
168   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
169   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
170   PetscCall(VecCreate_Seq_Private(v, host_array));
171   PetscFunctionReturn(PETSC_SUCCESS);
172 }
173 
174 // for functions with an early return based one vec size we still need to artificially bump the
175 // object state. This is to prevent the following:
176 //
177 // 0. Suppose you have a Vec {
178 //   rank 0: [0],
179 //   rank 1: [<empty>]
180 // }
181 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
182 // 2. Vec enters e.g. VecSet(10)
183 // 3. rank 1 has local size 0 and bails immediately
184 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
185 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
186 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
187 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
188 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
189 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
190 template <device::cupm::DeviceType T>
191 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
192 {
193   PetscFunctionBegin;
194   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
195   PetscFunctionReturn(PETSC_SUCCESS);
196 }
197 
// Common core for the create routines: builds the host-side representation (optionally
// adopting host_array) then runs the CUPM base-class initialization (optionally adopting
// device_array). Neither array's ownership is taken here.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
206 
// Common core for pointwise binary operations: zout[i] = binary(xin[i], yin[i]) on the
// device via thrust::transform. xin and yin are read, zout is written. For a locally empty
// vector the object state is still bumped (see MaybeIncrementEmptyLocalVec).
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
238 
// Common core for pointwise unary operations. With yin == nullptr (or yin == xinout) the
// transform is applied in place to xinout; otherwise xinout is read and yin is written.
// For a locally empty vector the appropriate output's object state is still bumped.
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;
    // shared launcher: reads xinout, writes either yin (out-of-place) or xinout (in place)
    const auto         apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
283 
284 // ==========================================================================================
285 // VecSeq_CUPM - Public API - Constructors
286 // ==========================================================================================
287 
288 // VecCreateSeqCUPM()
// Create a sequential CUPM vector of local (== global) size n and block size bs; when
// call_set_type is PETSC_FALSE the VecSetType() step is skipped (used by the WithArrays
// variants to avoid a recursive create).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
296 
// VecCreateSeqCUPMWithArrays()
// Create a sequential CUPM vector that adopts caller-provided host and/or device arrays
// (either may be null). Ownership of the arrays stays with the caller.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
311 
312 // v->ops->duplicate
313 template <device::cupm::DeviceType T>
314 inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
315 {
316   PetscDeviceContext dctx;
317 
318   PetscFunctionBegin;
319   PetscCall(GetHandles_(&dctx));
320   PetscCall(Duplicate_CUPMBase(v, y, dctx));
321   PetscFunctionReturn(PETSC_SUCCESS);
322 }
323 
324 // ==========================================================================================
325 // VecSeq_CUPM - Public API - Utility
326 // ==========================================================================================
327 
// v->ops->bindtocpu
// Toggle between host (usehost = PETSC_TRUE) and device execution by repointing the
// function-pointer table: each op is switched between its VecXXX_Seq host routine and the
// device implementation in this class.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device implementation, always use the host routine
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}
353 
354 // ==========================================================================================
355 // VecSeq_CUPM - Public API - Mutators
356 // ==========================================================================================
357 
// v->ops->getlocalvector or v->ops->getlocalvectorread
// Make w an alias for the local part of v. In the fast path (v native, w a seq CUPM vec)
// w's own storage is freed and v's data/spptr are shared into w; otherwise w borrows v's
// host array via VecGetArray(Read).
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // drop any storage w currently owns, since it is about to alias v's
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // free with the allocator matching how the array was pinned; exchange() also
        // clears w->pinned_memory
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: w aliases v wholesale (data, offload mask, pinned flag, device struct)
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // slow path: borrow v's host array into w's Vec_Seq representation
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
413 
// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Undo GetLocalVector(): hand the shared data/spptr (or the borrowed host array) back to
// v and sever w's references to it.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    // return the borrowed host array with the access mode matching GetLocalVector()
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      // w owns its own device array in this path, free it
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
449 
450 // ==========================================================================================
451 // VecSeq_CUPM - Public API - Compute Methods
452 // ==========================================================================================
453 
// v->ops->aypx
// y = alpha*y + x. alpha == 0 degenerates to a device-to-device copy; alpha == 1 to a
// single axpy; otherwise scal followed by axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    // y = x, a plain copy (DeviceArrayWrite() still bumps y's object state)
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      // keep the array handles alive (and their state updates ordered) for both BLAS calls
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // y = y + x
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        // y = alpha*y, then y = y + x
        const auto one = cupmScalarCast(1.0);

        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
497 
// v->ops->axpy
// y = alpha*x + y. Runs on the device only when x is itself a CUPM vector; otherwise the
// host routine is used (x's data may only exist on the host).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
522 
523 // v->ops->pointwisedivide
524 template <device::cupm::DeviceType T>
525 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
526 {
527   PetscFunctionBegin;
528   if (xin->boundtocpu || yin->boundtocpu) {
529     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
530   } else {
531     // note order of arguments! xin and yin are read, win is written!
532     PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
533   }
534   PetscFunctionReturn(PETSC_SUCCESS);
535 }
536 
537 // v->ops->pointwisemult
538 template <device::cupm::DeviceType T>
539 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
540 {
541   PetscFunctionBegin;
542   if (xin->boundtocpu || yin->boundtocpu) {
543     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
544   } else {
545     // note order of arguments! xin and yin are read, win is written!
546     PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
547   }
548   PetscFunctionReturn(PETSC_SUCCESS);
549 }
550 
551 namespace detail
552 {
553 
554 struct reciprocal {
555   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
556   {
557     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
558     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
559     // everything in PetscScalar...
560     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
561   }
562 };
563 
564 } // namespace detail
565 
// v->ops->reciprocal
// x[i] = 1/x[i] in place, leaving zero entries untouched (see detail::reciprocal)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
574 
// v->ops->waxpy
// w = alpha*x + y. alpha == 0 degenerates to w = y (a copy); otherwise y is copied into w
// and an axpy with x is applied on top.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      // hold the write handle across both the memcpy and the axpy
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
601 
namespace kernels
{

// Device kernel computing x[i] += sum_j aptr[j]*yptr_j[i] for a compile-time pack of
// sizeof...(Args) input vectors. The coefficients are staged through shared memory so each
// block reads them from global memory only once.
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    // accumulate starting from x[i]; kept (over the #if 0 variant) presumably for its
    // floating-point rounding behavior -- see comment above
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}

} // namespace kernels
638 
namespace detail
{

// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
// typename repeat_type<MyType, IdxParamPack>...
// expands to
// MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
// (the size_t parameter is intentionally unnamed and unused)
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};

} // namespace detail
653 
// Launch MAXPY_kernel for exactly sizeof...(Idx) input vectors; the index sequence both
// repeats the pointer type in the kernel's template arguments (via detail::repeat_type)
// and expands the yin[] accesses.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
670 
// Process the next batch of N vectors starting at yin[yidx]/aptr[yidx], then advance yidx
// by N so the caller's dispatch loop can continue with the remainder.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
680 
// v->ops->maxpy
// x += sum_i alpha[i]*y[i]. The nv input vectors are consumed in batches of (up to) 8 per
// kernel launch; the host-side alpha array is first copied to a device buffer.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    // batch the remaining vectors: exactly 1-7 left -> one final launch of that size,
    // otherwise chew through them 8 at a time
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
737 
// v->ops->dot
// z = conj(x) . y via cupmBlasXdot; an empty vector yields z = 0 without touching the
// device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
758 
759   #define MDOT_WORKGROUP_NUM  128
760   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
761 
762 namespace kernels
763 {
764 
765 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
766 {
767   const auto group_entries = (size - 1) / gridDim.x + 1;
768   // for very small vectors, a group should still do some work
769   return group_entries ? group_entries : 1;
770 }
771 
// Kernel computing, for each of the N vectors y_0..y_{N-1} passed as trailing pointer
// arguments, per-block partial sums of sum_i y_j[i] * x[i]. Block bx writes its partial
// result for vector j to results[bx + j * gridDim.x]; a follow-up reduction (sum_kernel)
// folds the per-block partials into final values.
//
// NOTE(review): no conjugation is applied here, so on complex builds this is the
// unconjugated product -- the complex MDot path uses BLAS instead (see MDot_ overloads).
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  // N = number of y vectors handled by this instantiation
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  // per-thread running sums, one per y vector
  PetscScalar        sumlocal[N];

  // one MDOT_WORKGROUP_SIZE-wide slab of shared memory per y vector
  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  // each thread strides through this block's slice of x accumulating into sumlocal
  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stage per-thread sums in shared memory for the tree reduction below
  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
820 
821 namespace
822 {
823 
// Final reduction for MDot_kernel's output: entry i of the reduced result is the sum of
// the MDOT_WORKGROUP_SIZE per-block partials stored at results[i*MDOT_WORKGROUP_SIZE ..
// (i+1)*MDOT_WORKGROUP_SIZE). The reduced values are written back in-place to the front
// of results[] (results[0..size)).
//
// NOTE(review): local_results has room for 8 entries, so each thread must own at most 8
// grid-stride iterations -- presumably guaranteed by the launch configuration in MDot_
// (one thread per output via PetscCUPMLaunchKernel1D); confirm if that launch changes.
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory; the second grid-stride walk visits the
  // same indices in the same order, so popping local_i retrieves the matching sums
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}
860 
861 } // namespace
862 
863 } // namespace kernels
864 
// Launches one MDot_kernel instantiation handling sizeof...(Idx) y-vectors at once. The
// index sequence is used purely to expand yin[Idx]... into the kernel's variadic pointer
// arguments; partial per-block results land in `results`.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      // repeat_type expands to one `const PetscScalar *` template argument per Idx
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
883 
// Convenience front-end for the index-sequence overload above: dispatches a batch of N
// y-vectors starting at yin[yidx], offsets the scratch buffer accordingly (each vector
// owns MDOT_WORKGROUP_NUM slots), and advances yidx past the batch.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
893 
// Real-scalar multi-dot implementation (std::false_type tag = !PETSC_USE_COMPLEX).
//
// Processes the nv y-vectors in batches of up to 8 via the custom MDot_kernel, optionally
// spreading batches over forked sub-contexts (separate streams) for concurrency. A lone
// trailing vector (nv % 8 == 1) is handed to BLAS dot directly instead of the kernel.
// Per-block partial sums accumulate in d_results and are folded by sum_kernel, then
// copied back to the host array z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // round-robin the batches over the sub-contexts (if any)
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a single remaining vector is cheaper via the vendor BLAS dot, writing its result
        // straight into z rather than the workgroup scratchpad
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // fold each vector's MDOT_WORKGROUP_NUM per-block partials into a single value in-place
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}
983 
984   #undef MDOT_WORKGROUP_NUM
985   #undef MDOT_WORKGROUP_SIZE
986 
// Complex-scalar multi-dot implementation (std::true_type tag = PETSC_USE_COMPLEX).
//
// Issues one BLAS dot per y-vector, spread round-robin over up to 8 forked sub-contexts.
// The BLAS pointer mode is temporarily switched to DEVICE so each result is written
// directly into the device buffer d_z without a host round-trip per call; d_z is copied
// back to z in one async transfer at the end.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // save/restore the pointer mode around the call so the handle is left as we found it
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}
1021 
// v->ops->mdot
//
// z[i] = conj(y_i) . x for i in [0, nv). Dispatches to the real or complex MDot_
// implementation at compile time via the USE_COMPLEX integral constant.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    // nv <= 0 (other than the nv == 1 fast path above) is a caller error
    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: every dot product is exactly zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1044 
1045 // v->ops->set
1046 template <device::cupm::DeviceType T>
1047 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1048 {
1049   const auto         n = xin->map->n;
1050   PetscDeviceContext dctx;
1051   cupmStream_t       stream;
1052 
1053   PetscFunctionBegin;
1054   PetscCall(GetHandles_(&dctx, &stream));
1055   {
1056     const auto xptr = DeviceArrayWrite(dctx, xin);
1057 
1058     if (alpha == PetscScalar(0.0)) {
1059       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1060     } else {
1061       const auto dptr = thrust::device_pointer_cast(xptr.data());
1062 
1063       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1064     }
1065     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1066   }
1067   PetscFunctionReturn(PETSC_SUCCESS);
1068 }
1069 
1070 // v->ops->scale
1071 template <device::cupm::DeviceType T>
1072 inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1073 {
1074   PetscFunctionBegin;
1075   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1076   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1077     PetscCall(Set(xin, alpha));
1078   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1079     PetscDeviceContext dctx;
1080     cupmBlasHandle_t   cupmBlasHandle;
1081 
1082     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1083     PetscCall(PetscLogGpuTimeBegin());
1084     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1085     PetscCall(PetscLogGpuTimeEnd());
1086     PetscCall(PetscLogGpuFlops(n));
1087     PetscCall(PetscDeviceContextSynchronize(dctx));
1088   } else {
1089     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1090   }
1091   PetscFunctionReturn(PETSC_SUCCESS);
1092 }
1093 
1094 // v->ops->tdot
1095 template <device::cupm::DeviceType T>
1096 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1097 {
1098   PetscFunctionBegin;
1099   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1100     PetscDeviceContext dctx;
1101     cupmBlasHandle_t   cupmBlasHandle;
1102 
1103     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1104     PetscCall(PetscLogGpuTimeBegin());
1105     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1106     PetscCall(PetscLogGpuTimeEnd());
1107     PetscCall(PetscLogGpuFlops(2 * n - 1));
1108   } else {
1109     *z = 0.0;
1110   }
1111   PetscFunctionReturn(PETSC_SUCCESS);
1112 }
1113 
// v->ops->copy
//
// y = x. Picks a memcpy direction based on where each vector's up-to-date data
// currently lives (the offload masks), preferring device-to-device when possible
// and only falling back to host-side copies when y is host-resident.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      // unallocated CUPM vectors can accept a device-side write; anything else is
      // treated as host-resident (fall through to the CPU case)
      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // host-destination copies go through the generic VecGetArrayWrite() so the
      // vector's offload state is updated correctly
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1183 
1184 // v->ops->swap
1185 template <device::cupm::DeviceType T>
1186 inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1187 {
1188   PetscFunctionBegin;
1189   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1190   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1191     PetscDeviceContext dctx;
1192     cupmBlasHandle_t   cupmBlasHandle;
1193 
1194     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1195     PetscCall(PetscLogGpuTimeBegin());
1196     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1197     PetscCall(PetscLogGpuTimeEnd());
1198     PetscCall(PetscDeviceContextSynchronize(dctx));
1199   } else {
1200     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1201     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1202   }
1203   PetscFunctionReturn(PETSC_SUCCESS);
1204 }
1205 
// v->ops->axpby
//
// y = alpha*x + beta*y. Degenerate coefficient values are routed to cheaper ops
// (Scale/AXPY/AYPX); the general case is a BLAS scal followed by axpy, with a
// memcpy+scal shortcut when beta == 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // y = beta*y
    PetscCall(Scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    // y = alpha*x + y
    PetscCall(AXPY(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    // y = x + beta*y
    PetscCall(AYPX(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      // NOTE: PetscLogGpuTimeBegin() is called inside each branch but the matching
      // End() is after this scope -- keep that pairing intact when editing
      if (betaIsZero /* beta = 0 */) {
        // y = alpha*x: copy x into y then scale in place
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        // y = beta*y; y += alpha*x
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    // 1 flop/entry for the scale-only path, 3 for scal + axpy
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1252 
// v->ops->axpbypcz
//
// z = alpha*x + beta*y + gamma*z, composed from the primitive Scale/AXPY ops
// rather than a fused kernel.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma)); // skip the no-op scale by 1
  PetscCall(AXPY(zin, alpha, xin));
  PetscCall(AXPY(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1263 
// v->ops->norm
//
// Computes the requested norm(s) of x via BLAS (asum for 1-norm, nrm2 for 2-norm,
// amax + single-element copy for the infinity norm). NORM_1_AND_2 writes the
// 1-norm to z[0] and falls through to put the 2-norm in z[1].
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      // NOTE(review): cupmBlasXasum sums |Re| + |Im| for complex scalars, which
      // differs from sum |x_i| -- presumably accepted behavior here, as with
      // other BLAS-based implementations
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // amax returns a 1-based index of the largest-magnitude entry; copy just
      // that entry back and take its absolute value on the host
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: all norms are zero (z[1] as well for NORM_1_AND_2)
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1313 
1314 namespace detail
1315 {
1316 
1317 struct dotnorm2_mult {
1318   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1319   {
1320     const auto conjt = PetscConj(t);
1321 
1322     return {s * conjt, t * conjt};
1323   }
1324 };
1325 
1326 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1327 // would do it myself but now I am worried that they do so on purpose...
1328 struct dotnorm2_tuple_plus {
1329   using value_type = thrust::tuple<PetscScalar, PetscScalar>;
1330 
1331   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
1332 };
1333 
1334 } // namespace detail
1335 
// v->ops->dotnorm2
//
// Computes dp = conj(t) . s and nm = conj(t) . t in a single fused pass over the
// data using thrust::inner_product with tuple-valued transform/reduce functors.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // initial values for the tuple reduction (dot, norm^2)
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  // NOTE(review): no GPU flop logging here, unlike the other reductions in this file
  PetscFunctionReturn(PETSC_SUCCESS);
}
1363 
namespace detail
{

// unary functor returning PetscConj(x); used by Conjugate() below (only invoked on
// complex builds, where PetscConj actually does work)
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail
1372 
// v->ops->conjugate
//
// x = conj(x), element-wise.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  // conjugation is the identity for real scalars, so only complex builds do any work
  if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1381 
namespace detail
{

// projects PetscScalar(s) to their real parts; the tuple overload preserves the
// accompanying index so it can be used with the zip iterator in MinMax_()
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};

// binary reduction functor selecting the "better" (value, index) pair according
// to Operator (e.g. thrust::greater for Max, thrust::less for Min), breaking
// value ties in favor of the lower index.
//
// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ? y : x;
    }
    // otherwise return x
    return x;
  }

private:
  // view *this through the base-class lens to invoke the stored comparator
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};

} // namespace detail
1417 
// Shared implementation for Max()/Min(). Reduces the vector on the device with thrust,
// optionally tracking the index of the extremum (when p is non-null, via a zip with a
// counting iterator). The caller pre-seeds *m with the reduction's identity value
// (PETSC_MIN_REAL for max, PETSC_MAX_REAL for min) and, if requested, *p is set to -1
// for empty vectors.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data())

      if (p) {
        // (value, index) reduction: zip the data with a 0-based counting iterator
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value-only reduction
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}
1474 
1475 // v->ops->max
1476 template <device::cupm::DeviceType T>
1477 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
1478 {
1479   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1480   using unary_functor = thrust::maximum<PetscReal>;
1481 
1482   PetscFunctionBegin;
1483   *m = PETSC_MIN_REAL;
1484   // use {} constructor syntax otherwise most vexing parse
1485   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1486   PetscFunctionReturn(PETSC_SUCCESS);
1487 }
1488 
1489 // v->ops->min
1490 template <device::cupm::DeviceType T>
1491 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
1492 {
1493   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1494   using unary_functor = thrust::minimum<PetscReal>;
1495 
1496   PetscFunctionBegin;
1497   *m = PETSC_MAX_REAL;
1498   // use {} constructor syntax otherwise most vexing parse
1499   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1500   PetscFunctionReturn(PETSC_SUCCESS);
1501 }
1502 
1503 // v->ops->sum
1504 template <device::cupm::DeviceType T>
1505 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
1506 {
1507   PetscFunctionBegin;
1508   if (const auto n = v->map->n) {
1509     PetscDeviceContext dctx;
1510     cupmStream_t       stream;
1511 
1512     PetscCall(GetHandles_(&dctx, &stream));
1513     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1514     // REVIEW ME: why not cupmBlasXasum()?
1515     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1516     // REVIEW ME: must be at least n additions
1517     PetscCall(PetscLogGpuFlops(n));
1518   } else {
1519     *sum = 0.0;
1520   }
1521   PetscFunctionReturn(PETSC_SUCCESS);
1522 }
1523 
1524 template <device::cupm::DeviceType T>
1525 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
1526 {
1527   PetscFunctionBegin;
1528   PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v));
1529   PetscFunctionReturn(PETSC_SUCCESS);
1530 }
1531 
1532 template <device::cupm::DeviceType T>
1533 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
1534 {
1535   PetscFunctionBegin;
1536   if (const auto n = v->map->n) {
1537     PetscBool          iscurand;
1538     PetscDeviceContext dctx;
1539 
1540     PetscCall(GetHandles_(&dctx));
1541     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
1542     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
1543     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
1544   } else {
1545     PetscCall(MaybeIncrementEmptyLocalVec(v));
1546   }
1547   // REVIEW ME: flops????
1548   // REVIEW ME: Timing???
1549   PetscFunctionReturn(PETSC_SUCCESS);
1550 }
1551 
// v->ops->setpreallocationcoo
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // build the host-side COO assembly data first (the Vec_Seq implementation), ...
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  // ... then mirror it to the device so SetValuesCOO() can consume it there; must run
  // after the host setup since it copies what the previous call produced
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1564 
1565 namespace kernels
1566 {
1567 
1568 template <typename F>
1569 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
1570 {
1571   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
1572     const auto  end = jmap[i + 1];
1573     const auto  idx = xvindex(i);
1574     PetscScalar sum = 0.0;
1575 
1576     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
1577 
1578     if (imode == INSERT_VALUES) {
1579       xv[idx] = sum;
1580     } else {
1581       xv[idx] += sum;
1582     }
1583   });
1584   return;
1585 }
1586 
1587 namespace
1588 {
1589 
1590 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
1591 {
1592   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
1593   return;
1594 }
1595 
1596 } // namespace
1597 
  #if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
//
// Each shim below simply odr-uses a kernel (cast-to-void) so the host compiler considers
// it referenced; the shims themselves are never called.
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
  #endif
1622 
1623 } // namespace kernels
1624 
// v->ops->setvaluescoo
//
// Assemble the values v[] (given in the COO ordering established by SetPreallocationCOO())
// into x on the device. v[] may live in host or device memory; a host array is staged
// through a temporary device allocation. imode selects overwrite (INSERT_VALUES) vs
// accumulate semantics inside the kernel.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v); // device-visible view of v (possibly a staged copy)
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    // NOTE(review): size is a count of scalars here, not bytes — presumably both helpers
    // take element counts; confirm against their declarations
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES fully overwrites x so write-only access suffices; ADD_VALUES must read
    // the existing entries, hence read-write access
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // free the staging buffer (only allocated for host-resident v[])
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  // synchronize so all stream-ordered work above has completed before returning to the
  // caller, who may immediately reuse or free v[]
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1657 
1658 } // namespace impl
1659 
1660 // ==========================================================================================
1661 // VecSeq_CUPM - Implementations
1662 // ==========================================================================================
1663 
1664 namespace
1665 {
1666 
1667 template <device::cupm::DeviceType T>
1668 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
1669 {
1670   PetscFunctionBegin;
1671   PetscValidPointer(v, 4);
1672   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
1673   PetscFunctionReturn(PETSC_SUCCESS);
1674 }
1675 
1676 template <device::cupm::DeviceType T>
1677 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
1678 {
1679   PetscFunctionBegin;
1680   if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4);
1681   PetscValidPointer(v, 6);
1682   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
1683   PetscFunctionReturn(PETSC_SUCCESS);
1684 }
1685 
1686 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
1687 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
1688 {
1689   PetscFunctionBegin;
1690   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
1691   PetscValidPointer(a, 2);
1692   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1693   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
1694   PetscFunctionReturn(PETSC_SUCCESS);
1695 }
1696 
1697 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
1698 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
1699 {
1700   PetscFunctionBegin;
1701   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
1702   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1703   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
1704   PetscFunctionReturn(PETSC_SUCCESS);
1705 }
1706 
1707 template <device::cupm::DeviceType T>
1708 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1709 {
1710   PetscFunctionBegin;
1711   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
1712   PetscFunctionReturn(PETSC_SUCCESS);
1713 }
1714 
1715 template <device::cupm::DeviceType T>
1716 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1717 {
1718   PetscFunctionBegin;
1719   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
1720   PetscFunctionReturn(PETSC_SUCCESS);
1721 }
1722 
1723 template <device::cupm::DeviceType T>
1724 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1725 {
1726   PetscFunctionBegin;
1727   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
1728   PetscFunctionReturn(PETSC_SUCCESS);
1729 }
1730 
1731 template <device::cupm::DeviceType T>
1732 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1733 {
1734   PetscFunctionBegin;
1735   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
1736   PetscFunctionReturn(PETSC_SUCCESS);
1737 }
1738 
1739 template <device::cupm::DeviceType T>
1740 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1741 {
1742   PetscFunctionBegin;
1743   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
1744   PetscFunctionReturn(PETSC_SUCCESS);
1745 }
1746 
1747 template <device::cupm::DeviceType T>
1748 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
1749 {
1750   PetscFunctionBegin;
1751   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
1752   PetscFunctionReturn(PETSC_SUCCESS);
1753 }
1754 
1755 template <device::cupm::DeviceType T>
1756 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
1757 {
1758   PetscFunctionBegin;
1759   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
1760   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
1761   PetscFunctionReturn(PETSC_SUCCESS);
1762 }
1763 
1764 template <device::cupm::DeviceType T>
1765 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
1766 {
1767   PetscFunctionBegin;
1768   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
1769   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
1770   PetscFunctionReturn(PETSC_SUCCESS);
1771 }
1772 
1773 template <device::cupm::DeviceType T>
1774 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
1775 {
1776   PetscFunctionBegin;
1777   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
1778   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
1779   PetscFunctionReturn(PETSC_SUCCESS);
1780 }
1781 
1782 } // anonymous namespace
1783 
1784 } // namespace cupm
1785 
1786 } // namespace vec
1787 
1788 } // namespace Petsc
1789 
1790 #endif // __cplusplus
1791 
1792 #endif // PETSCVECSEQCUPM_HPP
1793