xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 21e3ffae2f3b73c0bd738cf6d0a809700fc04bb0)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 #include <petsc/private/randomimpl.h> // for _p_PetscRandom
6 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
7 #include "../src/sys/objects/device/impls/cupm/cupmallocator.hpp"
8 #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
9 
10 #if defined(__cplusplus)
11   #include <thrust/transform_reduce.h>
12   #include <thrust/reduce.h>
13   #include <thrust/functional.h>
14   #include <thrust/iterator/counting_iterator.h>
15   #include <thrust/inner_product.h>
16 
17 namespace Petsc
18 {
19 
20 namespace vec
21 {
22 
23 namespace cupm
24 {
25 
26 namespace impl
27 {
28 
29 // ==========================================================================================
30 // VecSeq_CUPM
31 // ==========================================================================================
32 
33 template <device::cupm::DeviceType T>
34 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
35 public:
36   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
37 
38 private:
39   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
40   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
41 
42   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
43   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
44   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
45   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
46 
47   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
48 
49   // common core for min and max
50   template <typename TupleFuncT, typename UnaryFuncT>
51   static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
52   // common core for pointwise binary and pointwise unary thrust functions
53   template <typename BinaryFuncT>
54   static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
55   template <typename UnaryFuncT>
56   static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
57   // mdot dispatchers
58   static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
59   static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
60   template <std::size_t... Idx>
61   static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
62   template <int>
63   static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
64   template <std::size_t... Idx>
65   static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
66   template <int>
67   static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
68   // common core for the various create routines
69   static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
70 
71 public:
72   // callable directly via a bespoke function
73   static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
74   static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
75 
76   // callable indirectly via function pointers
77   static PetscErrorCode duplicate(Vec, Vec *) noexcept;
78   static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept;
79   static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept;
80   static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept;
81   static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept;
82   static PetscErrorCode reciprocal(Vec) noexcept;
83   static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept;
84   static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
85   static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept;
86   static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
87   static PetscErrorCode set(Vec, PetscScalar) noexcept;
88   static PetscErrorCode scale(Vec, PetscScalar) noexcept;
89   static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept;
90   static PetscErrorCode copy(Vec, Vec) noexcept;
91   static PetscErrorCode swap(Vec, Vec) noexcept;
92   static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept;
93   static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
94   static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept;
95   static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
96   static PetscErrorCode destroy(Vec) noexcept;
97   static PetscErrorCode conjugate(Vec) noexcept;
98   template <PetscMemoryAccessMode>
99   static PetscErrorCode getlocalvector(Vec, Vec) noexcept;
100   template <PetscMemoryAccessMode>
101   static PetscErrorCode restorelocalvector(Vec, Vec) noexcept;
102   static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept;
103   static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept;
104   static PetscErrorCode sum(Vec, PetscScalar *) noexcept;
105   static PetscErrorCode shift(Vec, PetscScalar) noexcept;
106   static PetscErrorCode setrandom(Vec, PetscRandom) noexcept;
107   static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept;
108   static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept;
109   static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept;
110 };
111 
112 // ==========================================================================================
113 // VecSeq_CUPM - Private API
114 // ==========================================================================================
115 
116 template <device::cupm::DeviceType T>
117 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
118 {
119   return static_cast<Vec_Seq *>(v->data);
120 }
121 
122 template <device::cupm::DeviceType T>
123 inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
124 {
125   return VECSEQCUPM();
126 }
127 
128 template <device::cupm::DeviceType T>
129 inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
130 {
131   return VecDestroy_Seq(v);
132 }
133 
134 template <device::cupm::DeviceType T>
135 inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
136 {
137   return VecResetArray_Seq(v);
138 }
139 
140 template <device::cupm::DeviceType T>
141 inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
142 {
143   return VecPlaceArray_Seq(v, a);
144 }
145 
146 template <device::cupm::DeviceType T>
147 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
148 {
149   PetscMPIInt size;
150 
151   PetscFunctionBegin;
152   if (alloc_missing) *alloc_missing = PETSC_FALSE;
153   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
154   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
155   PetscCall(VecCreate_Seq_Private(v, host_array));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158 
159 // for functions with an early return based one vec size we still need to artificially bump the
160 // object state. This is to prevent the following:
161 //
162 // 0. Suppose you have a Vec {
163 //   rank 0: [0],
164 //   rank 1: [<empty>]
165 // }
166 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
167 // 2. Vec enters e.g. VecSet(10)
168 // 3. rank 1 has local size 0 and bails immediately
169 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
170 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
171 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
172 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
173 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
174 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
175 template <device::cupm::DeviceType T>
176 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
177 {
178   PetscFunctionBegin;
179   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
180   PetscFunctionReturn(PETSC_SUCCESS);
181 }
182 
183 template <device::cupm::DeviceType T>
184 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
185 {
186   PetscFunctionBegin;
187   PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
188   PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
189   PetscFunctionReturn(PETSC_SUCCESS);
190 }
191 
192 template <device::cupm::DeviceType T>
193 template <typename BinaryFuncT>
194 inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
195 {
196   PetscFunctionBegin;
197   if (const auto n = zout->map->n) {
198     PetscDeviceContext dctx;
199 
200     PetscCall(GetHandles_(&dctx));
201     PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<BinaryFuncT>(binary), n, DeviceArrayRead(dctx, xin).data(), DeviceArrayRead(dctx, yin).data(), DeviceArrayWrite(dctx, zout).data()));
202     PetscCall(PetscDeviceContextSynchronize(dctx));
203   } else {
204     PetscCall(MaybeIncrementEmptyLocalVec(zout));
205   }
206   PetscFunctionReturn(PETSC_SUCCESS);
207 }
208 
209 template <device::cupm::DeviceType T>
210 template <typename UnaryFuncT>
211 inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
212 {
213   const auto inplace = !yin || (xinout == yin);
214 
215   PetscFunctionBegin;
216   if (const auto n = xinout->map->n) {
217     PetscDeviceContext dctx;
218 
219     PetscCall(GetHandles_(&dctx));
220     if (inplace) {
221       PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayReadWrite(dctx, xinout).data()));
222     } else {
223       PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
224     }
225     PetscCall(PetscDeviceContextSynchronize(dctx));
226   } else {
227     if (inplace) {
228       PetscCall(MaybeIncrementEmptyLocalVec(xinout));
229     } else {
230       PetscCall(MaybeIncrementEmptyLocalVec(yin));
231     }
232   }
233   PetscFunctionReturn(PETSC_SUCCESS);
234 }
235 
236 // ==========================================================================================
237 // VecSeq_CUPM - Public API - Constructors
238 // ==========================================================================================
239 
240 // VecCreateSeqCUPM()
241 template <device::cupm::DeviceType T>
242 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
243 {
244   PetscFunctionBegin;
245   PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
246   PetscFunctionReturn(PETSC_SUCCESS);
247 }
248 
249 // VecCreateSeqCUPMWithArrays()
250 template <device::cupm::DeviceType T>
251 inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
252 {
253   PetscDeviceContext dctx;
254 
255   PetscFunctionBegin;
256   PetscCall(GetHandles_(&dctx));
257   // do NOT call VecSetType(), otherwise ops->create() -> create() ->
258   // createseqcupm_() is called!
259   PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE));
260   PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
261   PetscFunctionReturn(PETSC_SUCCESS);
262 }
263 
264 // v->ops->duplicate
265 template <device::cupm::DeviceType T>
266 inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept
267 {
268   PetscDeviceContext dctx;
269 
270   PetscFunctionBegin;
271   PetscCall(GetHandles_(&dctx));
272   PetscCall(Duplicate_CUPMBase(v, y, dctx));
273   PetscFunctionReturn(PETSC_SUCCESS);
274 }
275 
276 // ==========================================================================================
277 // VecSeq_CUPM - Public API - Utility
278 // ==========================================================================================
279 
280 // v->ops->bindtocpu
281 template <device::cupm::DeviceType T>
282 inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept
283 {
284   PetscDeviceContext dctx;
285 
286   PetscFunctionBegin;
287   PetscCall(GetHandles_(&dctx));
288   PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));
289 
290   // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess
291   VecSetOp_CUPM(dot, VecDot_Seq, dot);
292   VecSetOp_CUPM(norm, VecNorm_Seq, norm);
293   VecSetOp_CUPM(tdot, VecTDot_Seq, tdot);
294   VecSetOp_CUPM(mdot, VecMDot_Seq, mdot);
295   VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>);
296   VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>);
297   v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
298   VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate);
299   VecSetOp_CUPM(max, VecMax_Seq, max);
300   VecSetOp_CUPM(min, VecMin_Seq, min);
301   VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo);
302   VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo);
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 // ==========================================================================================
307 // VecSeq_CUPM - Public API - Mutators
308 // ==========================================================================================
309 
310 // v->ops->getlocalvector or v->ops->getlocalvectorread
311 template <device::cupm::DeviceType T>
312 template <PetscMemoryAccessMode access>
313 inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept
314 {
315   PetscBool wisseqcupm;
316 
317   PetscFunctionBegin;
318   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
319   PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
320   if (wisseqcupm) {
321     if (const auto wseq = VecIMPLCast(w)) {
322       if (auto &alloced = wseq->array_allocated) {
323         const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));
324 
325         PetscCall(PetscFree(alloced));
326       }
327       wseq->array         = nullptr;
328       wseq->unplacedarray = nullptr;
329     }
330     if (const auto wcu = VecCUPMCast(w)) {
331       if (auto &device_array = wcu->array_d) {
332         cupmStream_t stream;
333 
334         PetscCall(GetHandles_(&stream));
335         PetscCallCUPM(cupmFreeAsync(device_array, stream));
336       }
337       PetscCall(PetscFree(w->spptr /* wcu */));
338     }
339   }
340   if (v->petscnative && wisseqcupm) {
341     PetscCall(PetscFree(w->data));
342     w->data          = v->data;
343     w->offloadmask   = v->offloadmask;
344     w->pinned_memory = v->pinned_memory;
345     w->spptr         = v->spptr;
346     PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
347   } else {
348     const auto array = &VecIMPLCast(w)->array;
349 
350     if (access == PETSC_MEMORY_ACCESS_READ) {
351       PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
352     } else {
353       PetscCall(VecGetArray(v, array));
354     }
355     w->offloadmask = PETSC_OFFLOAD_CPU;
356     if (wisseqcupm) {
357       PetscDeviceContext dctx;
358 
359       PetscCall(GetHandles_(&dctx));
360       PetscCall(DeviceAllocateCheck_(dctx, w));
361     }
362   }
363   PetscFunctionReturn(PETSC_SUCCESS);
364 }
365 
366 // v->ops->restorelocalvector or v->ops->restorelocalvectorread
367 template <device::cupm::DeviceType T>
368 template <PetscMemoryAccessMode access>
369 inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept
370 {
371   PetscBool wisseqcupm;
372 
373   PetscFunctionBegin;
374   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
375   PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
376   if (v->petscnative && wisseqcupm) {
377     // the assignments to nullptr are __critical__, as w may persist after this call returns
378     // and shouldn't share data with v!
379     v->pinned_memory = w->pinned_memory;
380     v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
381     v->data          = util::exchange(w->data, nullptr);
382     v->spptr         = util::exchange(w->spptr, nullptr);
383   } else {
384     const auto array = &VecIMPLCast(w)->array;
385 
386     if (access == PETSC_MEMORY_ACCESS_READ) {
387       PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
388     } else {
389       PetscCall(VecRestoreArray(v, array));
390     }
391     if (w->spptr && wisseqcupm) {
392       cupmStream_t stream;
393 
394       PetscCall(GetHandles_(&stream));
395       PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
396       PetscCall(PetscFree(w->spptr));
397     }
398   }
399   PetscFunctionReturn(PETSC_SUCCESS);
400 }
401 
402 // ==========================================================================================
403 // VecSeq_CUPM - Public API - Compute Methods
404 // ==========================================================================================
405 
406 // v->ops->aypx
407 template <device::cupm::DeviceType T>
408 inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept
409 {
410   const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
411   const auto         sync = n != 0;
412   PetscDeviceContext dctx;
413 
414   PetscFunctionBegin;
415   PetscCall(GetHandles_(&dctx));
416   if (alpha == PetscScalar(0.0)) {
417     cupmStream_t stream;
418 
419     PetscCall(GetHandlesFrom_(dctx, &stream));
420     PetscCall(PetscLogGpuTimeBegin());
421     PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
422     PetscCall(PetscLogGpuTimeEnd());
423   } else if (n) {
424     const auto       alphaIsOne = alpha == PetscScalar(1.0);
425     const auto       calpha     = cupmScalarPtrCast(&alpha);
426     cupmBlasHandle_t cupmBlasHandle;
427 
428     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
429     {
430       const auto yptr = DeviceArrayReadWrite(dctx, yin);
431       const auto xptr = DeviceArrayRead(dctx, xin);
432 
433       PetscCall(PetscLogGpuTimeBegin());
434       if (alphaIsOne) {
435         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
436       } else {
437         const auto one = cupmScalarCast(1.0);
438 
439         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
440         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
441       }
442       PetscCall(PetscLogGpuTimeEnd());
443     }
444     PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
445   }
446   if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
447   PetscFunctionReturn(PETSC_SUCCESS);
448 }
449 
450 // v->ops->axpy
451 template <device::cupm::DeviceType T>
452 inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept
453 {
454   PetscBool xiscupm;
455 
456   PetscFunctionBegin;
457   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
458   if (xiscupm) {
459     const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
460     PetscDeviceContext dctx;
461     cupmBlasHandle_t   cupmBlasHandle;
462 
463     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
464     PetscCall(PetscLogGpuTimeBegin());
465     PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
466     PetscCall(PetscLogGpuTimeEnd());
467     PetscCall(PetscLogGpuFlops(2 * n));
468     PetscCall(PetscDeviceContextSynchronize(dctx));
469   } else {
470     PetscCall(VecAXPY_Seq(yin, alpha, xin));
471   }
472   PetscFunctionReturn(PETSC_SUCCESS);
473 }
474 
475 // v->ops->pointwisedivide
476 template <device::cupm::DeviceType T>
477 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept
478 {
479   PetscFunctionBegin;
480   if (xin->boundtocpu || yin->boundtocpu) {
481     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
482   } else {
483     // note order of arguments! xin and yin are read, win is written!
484     PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
485   }
486   PetscFunctionReturn(PETSC_SUCCESS);
487 }
488 
489 // v->ops->pointwisemult
490 template <device::cupm::DeviceType T>
491 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept
492 {
493   PetscFunctionBegin;
494   if (xin->boundtocpu || yin->boundtocpu) {
495     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
496   } else {
497     // note order of arguments! xin and yin are read, win is written!
498     PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
499   }
500   PetscFunctionReturn(PETSC_SUCCESS);
501 }
502 
503 namespace detail
504 {
505 
506 struct reciprocal {
507   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
508   {
509     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
510     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
511     // everything in PetscScalar...
512     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
513   }
514 };
515 
516 } // namespace detail
517 
518 // v->ops->reciprocal
519 template <device::cupm::DeviceType T>
520 inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept
521 {
522   PetscFunctionBegin;
523   PetscCall(pointwiseunary_(detail::reciprocal{}, xin));
524   PetscFunctionReturn(PETSC_SUCCESS);
525 }
526 
527 // v->ops->waxpy
528 template <device::cupm::DeviceType T>
529 inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
530 {
531   PetscFunctionBegin;
532   if (alpha == PetscScalar(0.0)) {
533     PetscCall(copy(yin, win));
534   } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
535     PetscDeviceContext dctx;
536     cupmBlasHandle_t   cupmBlasHandle;
537     cupmStream_t       stream;
538 
539     PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
540     {
541       const auto wptr = DeviceArrayWrite(dctx, win);
542 
543       PetscCall(PetscLogGpuTimeBegin());
544       PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
545       PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
546       PetscCall(PetscLogGpuTimeEnd());
547     }
548     PetscCall(PetscLogGpuFlops(2 * n));
549     PetscCall(PetscDeviceContextSynchronize(dctx));
550   }
551   PetscFunctionReturn(PETSC_SUCCESS);
552 }
553 
554 namespace kernels
555 {
556 
557 template <typename... Args>
558 PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
559 {
560   constexpr int      N        = sizeof...(Args);
561   const auto         tx       = threadIdx.x;
562   const PetscScalar *yptr_p[] = {yptr...};
563 
564   PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];
565 
566   // load a to shared memory
567   if (tx < N) aptr_shmem[tx] = aptr[tx];
568   __syncthreads();
569 
570   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
571     // these may look the same but give different results!
572   #if 0
573     PetscScalar sum = 0.0;
574 
575     #pragma unroll
576     for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
577     xptr[i] += sum;
578   #else
579     auto sum = xptr[i];
580 
581     #pragma unroll
582     for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
583     xptr[i] = sum;
584   #endif
585   });
586   return;
587 }
588 
589 } // namespace kernels
590 
591 namespace detail
592 {
593 
594 // a helper-struct to gobble the size_t input, it is used with template parameter pack
595 // expansion such that
596 // typename repeat_type<MyType, IdxParamPack>...
597 // expands to
598 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
599 template <typename T, std::size_t>
600 struct repeat_type {
601   using type = T;
602 };
603 
604 } // namespace detail
605 
606 template <device::cupm::DeviceType T>
607 template <std::size_t... Idx>
608 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
609 {
610   PetscFunctionBegin;
611   // clang-format off
612   PetscCall(
613     PetscCUPMLaunchKernel1D(
614       size, 0, stream,
615       kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
616       size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
617     )
618   );
619   // clang-format on
620   PetscFunctionReturn(PETSC_SUCCESS);
621 }
622 
623 template <device::cupm::DeviceType T>
624 template <int N>
625 inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
626 {
627   PetscFunctionBegin;
628   PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
629   yidx += N;
630   PetscFunctionReturn(PETSC_SUCCESS);
631 }
632 
633 // v->ops->maxpy
634 template <device::cupm::DeviceType T>
635 inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
636 {
637   const auto         n = xin->map->n;
638   PetscDeviceContext dctx;
639   cupmStream_t       stream;
640 
641   PetscFunctionBegin;
642   PetscCall(GetHandles_(&dctx, &stream));
643   {
644     const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
645     PetscScalar *d_alpha = nullptr;
646     PetscInt     yidx    = 0;
647 
648     // placement of early-return is deliberate, we would like to capture the
649     // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail
650     if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
651     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
652     PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
653     PetscCall(PetscLogGpuTimeBegin());
654     do {
655       switch (nv - yidx) {
656       case 7:
657         PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
658         break;
659       case 6:
660         PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
661         break;
662       case 5:
663         PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
664         break;
665       case 4:
666         PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
667         break;
668       case 3:
669         PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
670         break;
671       case 2:
672         PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
673         break;
674       case 1:
675         PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
676         break;
677       default: // 8 or more
678         PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
679         break;
680       }
681     } while (yidx < nv);
682     PetscCall(PetscLogGpuTimeEnd());
683     PetscCall(PetscDeviceFree(dctx, d_alpha));
684   }
685   PetscCall(PetscLogGpuFlops(nv * 2 * n));
686   PetscCall(PetscDeviceContextSynchronize(dctx));
687   PetscFunctionReturn(PETSC_SUCCESS);
688 }
689 
690 template <device::cupm::DeviceType T>
691 inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept
692 {
693   PetscFunctionBegin;
694   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
695     PetscDeviceContext dctx;
696     cupmBlasHandle_t   cupmBlasHandle;
697 
698     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
699     // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
700     // second
701     PetscCall(PetscLogGpuTimeBegin());
702     PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
703     PetscCall(PetscLogGpuTimeEnd());
704     PetscCall(PetscLogGpuFlops(2 * n - 1));
705   } else {
706     *z = 0.0;
707   }
708   PetscFunctionReturn(PETSC_SUCCESS);
709 }
710 
711   #define MDOT_WORKGROUP_NUM  128
712   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
713 
714 namespace kernels
715 {
716 
717 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
718 {
719   const auto group_entries = (size - 1) / gridDim.x + 1;
720   // for very small vectors, a group should still do some work
721   return group_entries ? group_entries : 1;
722 }
723 
724 template <typename... ConstPetscScalarPointer>
725 PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
726 {
727   constexpr int      N        = sizeof...(ConstPetscScalarPointer);
728   const PetscScalar *ylocal[] = {y...};
729   PetscScalar        sumlocal[N];
730 
731   PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];
732 
733   // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
734   // types, so each of these go on separate lines...
735   const auto tx       = threadIdx.x;
736   const auto bx       = blockIdx.x;
737   const auto bdx      = blockDim.x;
738   const auto gdx      = gridDim.x;
739   const auto worksize = EntriesPerGroup(size);
740   const auto begin    = tx + bx * worksize;
741   const auto end      = min((bx + 1) * worksize, size);
742 
743   #pragma unroll
744   for (auto i = 0; i < N; ++i) sumlocal[i] = 0;
745 
746   for (auto i = begin; i < end; i += bdx) {
747     const auto xi = x[i]; // load only once from global memory!
748 
749   #pragma unroll
750     for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
751   }
752 
753   #pragma unroll
754   for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];
755 
756   // parallel reduction
757   for (auto stride = bdx / 2; stride > 0; stride /= 2) {
758     __syncthreads();
759     if (tx < stride) {
760   #pragma unroll
761       for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
762     }
763   }
764   // bottom N threads per block write to global memory
765   // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
766   // writes to the same sections in the above loop that it is about to read from below, but
767   // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
768   __syncthreads();
769   if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
770   return;
771 }
772 
773 namespace
774 {
775 
776 PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
777 {
778   int         local_i = 0;
779   PetscScalar local_results[8];
780 
781   // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
782   //
783   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
784   // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
785   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
786   //  |  ______________________________________________________/
787   //  | /            <- MDOT_WORKGROUP_NUM ->
788   //  |/
789   //  +
790   //  v
791   // *-*-*
792   // | | | ...
793   // *-*-*
794   //
795   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
796     PetscScalar z_sum = 0;
797 
798     for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
799     local_results[local_i++] = z_sum;
800   });
801   // if we needed more than 1 workgroup to handle the vector we should sync since other threads
802   // may currently be reading from results
803   if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
804   // Local buffer is now written to global memory
805   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
806     const auto j = --local_i;
807 
808     if (j >= 0) results[i] = local_results[j];
809   });
810   return;
811 }
812 
813 } // namespace
814 
815 } // namespace kernels
816 
817 template <device::cupm::DeviceType T>
818 template <std::size_t... Idx>
819 inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
820 {
821   PetscFunctionBegin;
822   // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
823   // 128 blocks of 128 threads every time which may be wasteful
824   // clang-format off
825   PetscCallCUPM(
826     cupmLaunchKernel(
827       kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
828       MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
829       xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
830     )
831   );
832   // clang-format on
833   PetscFunctionReturn(PETSC_SUCCESS);
834 }
835 
836 template <device::cupm::DeviceType T>
837 template <int N>
838 inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
839 {
840   PetscFunctionBegin;
841   PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
842   yidx += N;
843   PetscFunctionReturn(PETSC_SUCCESS);
844 }
845 
846 template <device::cupm::DeviceType T>
847 inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
848 {
849   // the largest possible size of a batch
850   constexpr PetscInt batchsize = 8;
851   // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
852   // do not create substreams. Note we don't create more than 8 streams, in practice we could
853   // not get more parallelism with higher numbers.
854   const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
855   const auto n               = xin->map->n;
856   // number of vectors that we handle via the batches. note any singletons are handled by
857   // cublas, hence the nv-1.
858   const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
859   const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
860   PetscScalar *d_results;
861   cupmStream_t stream;
862 
863   PetscFunctionBegin;
864   PetscCall(GetHandlesFrom_(dctx, &stream));
865   // allocate scratchpad memory for the results of individual work groups
866   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
867   {
868     const auto          xptr       = DeviceArrayRead(dctx, xin);
869     PetscInt            yidx       = 0;
870     auto                subidx     = 0;
871     auto                cur_stream = stream;
872     auto                cur_ctx    = dctx;
873     PetscDeviceContext *sub        = nullptr;
874     PetscStreamType     stype;
875 
876     // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
877     // sub. Ideally the parent context should also join in on the fork, but it is extremely
878     // fiddly to do so presently
879     PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
880     if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
881     // If we have a globally blocking stream create nonblocking streams instead (as we can
882     // locally exploit the parallelism). Otherwise use the prescribed stream type.
883     PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
884     PetscCall(PetscLogGpuTimeBegin());
885     do {
886       if (num_sub_streams) {
887         cur_ctx = sub[subidx++ % num_sub_streams];
888         PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
889       }
890       // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
891       // it is very likely better to do 4+5 rather than 8+1
892       switch (nv - yidx) {
893       case 7:
894         PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
895         break;
896       case 6:
897         PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
898         break;
899       case 5:
900         PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
901         break;
902       case 4:
903         PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
904         break;
905       case 3:
906         PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
907         break;
908       case 2:
909         PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
910         break;
911       case 1: {
912         cupmBlasHandle_t cupmBlasHandle;
913 
914         PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
915         PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
916         ++yidx;
917       } break;
918       default: // 8 or more
919         PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
920         break;
921       }
922     } while (yidx < nv);
923     PetscCall(PetscLogGpuTimeEnd());
924     PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
925   }
926 
927   PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
928   // copy result of device reduction to host
929   PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
930   // do these now while final reduction is in flight
931   PetscCall(PetscLogFlops(nwork));
932   PetscCall(PetscDeviceFree(dctx, d_results));
933   PetscFunctionReturn(PETSC_SUCCESS);
934 }
935 
936   #undef MDOT_WORKGROUP_NUM
937   #undef MDOT_WORKGROUP_SIZE
938 
939 template <device::cupm::DeviceType T>
940 inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
941 {
942   // probably not worth it to run more than 8 of these at a time?
943   const auto          n_sub = PetscMin(nv, 8);
944   const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
945   const auto          xptr  = DeviceArrayRead(dctx, xin);
946   PetscScalar        *d_z;
947   PetscDeviceContext *subctx;
948   cupmStream_t        stream;
949 
950   PetscFunctionBegin;
951   PetscCall(GetHandlesFrom_(dctx, &stream));
952   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
953   PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
954   PetscCall(PetscLogGpuTimeBegin());
955   for (PetscInt i = 0; i < nv; ++i) {
956     const auto            sub = subctx[i % n_sub];
957     cupmBlasHandle_t      handle;
958     cupmBlasPointerMode_t old_mode;
959 
960     PetscCall(GetHandlesFrom_(sub, &handle));
961     PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
962     if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
963     PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
964     if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
965   }
966   PetscCall(PetscLogGpuTimeEnd());
967   PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
968   PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
969   PetscCall(PetscDeviceFree(dctx, d_z));
970   // REVIEW ME: flops?????
971   PetscFunctionReturn(PETSC_SUCCESS);
972 }
973 
974 // v->ops->mdot
975 template <device::cupm::DeviceType T>
976 inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
977 {
978   PetscFunctionBegin;
979   if (PetscUnlikely(nv == 1)) {
980     // dot handles nv = 0 correctly
981     PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
982   } else if (const auto n = xin->map->n) {
983     PetscDeviceContext dctx;
984 
985     PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
986     PetscCall(GetHandles_(&dctx));
987     PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
988     // REVIEW ME: double count of flops??
989     PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
990     PetscCall(PetscDeviceContextSynchronize(dctx));
991   } else {
992     PetscCall(PetscArrayzero(z, nv));
993   }
994   PetscFunctionReturn(PETSC_SUCCESS);
995 }
996 
997 // v->ops->set
998 template <device::cupm::DeviceType T>
999 inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
1000 {
1001   const auto         n = xin->map->n;
1002   PetscDeviceContext dctx;
1003   cupmStream_t       stream;
1004 
1005   PetscFunctionBegin;
1006   PetscCall(GetHandles_(&dctx, &stream));
1007   {
1008     const auto xptr = DeviceArrayWrite(dctx, xin);
1009 
1010     if (alpha == PetscScalar(0.0)) {
1011       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1012     } else {
1013       PetscCall(device::cupm::impl::ThrustSet<T>(stream, n, xptr.data(), &alpha));
1014     }
1015     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1016   }
1017   PetscFunctionReturn(PETSC_SUCCESS);
1018 }
1019 
1020 // v->ops->scale
1021 template <device::cupm::DeviceType T>
1022 inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
1023 {
1024   PetscFunctionBegin;
1025   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1026   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1027     PetscCall(set(xin, alpha));
1028   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1029     PetscDeviceContext dctx;
1030     cupmBlasHandle_t   cupmBlasHandle;
1031 
1032     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1033     PetscCall(PetscLogGpuTimeBegin());
1034     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1035     PetscCall(PetscLogGpuTimeEnd());
1036     PetscCall(PetscLogGpuFlops(n));
1037     PetscCall(PetscDeviceContextSynchronize(dctx));
1038   } else {
1039     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1040   }
1041   PetscFunctionReturn(PETSC_SUCCESS);
1042 }
1043 
1044 // v->ops->tdot
1045 template <device::cupm::DeviceType T>
1046 inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
1047 {
1048   PetscFunctionBegin;
1049   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1050     PetscDeviceContext dctx;
1051     cupmBlasHandle_t   cupmBlasHandle;
1052 
1053     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1054     PetscCall(PetscLogGpuTimeBegin());
1055     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1056     PetscCall(PetscLogGpuTimeEnd());
1057     PetscCall(PetscLogGpuFlops(2 * n - 1));
1058   } else {
1059     *z = 0.0;
1060   }
1061   PetscFunctionReturn(PETSC_SUCCESS);
1062 }
1063 
1064 // v->ops->copy
1065 template <device::cupm::DeviceType T>
1066 inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
1067 {
1068   PetscFunctionBegin;
1069   if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
1070   if (const auto n = xin->map->n) {
1071     const auto xmask = xin->offloadmask;
1072     // silence buggy gcc warning: mode may be used uninitialized in this function
1073     auto               mode = cupmMemcpyDeviceToDevice;
1074     PetscDeviceContext dctx;
1075     cupmStream_t       stream;
1076 
1077     // translate from PetscOffloadMask to cupmMemcpyKind
1078     switch (const auto ymask = yout->offloadmask) {
1079     case PETSC_OFFLOAD_UNALLOCATED: {
1080       PetscBool yiscupm;
1081 
1082       PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1083       if (yiscupm) {
1084         mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
1085         break;
1086       }
1087     } // fall-through if unallocated and not cupm
1088   #if PETSC_CPP_VERSION >= 17
1089       [[fallthrough]];
1090   #endif
1091     case PETSC_OFFLOAD_CPU:
1092       mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
1093       break;
1094     case PETSC_OFFLOAD_BOTH:
1095     case PETSC_OFFLOAD_GPU:
1096       mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
1097       break;
1098     default:
1099       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
1100     }
1101 
1102     PetscCall(GetHandles_(&dctx, &stream));
1103     switch (mode) {
1104     case cupmMemcpyDeviceToDevice: // the best case
1105     case cupmMemcpyHostToDevice: { // not terrible
1106       const auto yptr = DeviceArrayWrite(dctx, yout);
1107       const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
1108 
1109       PetscCall(PetscLogGpuTimeBegin());
1110       PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
1111       PetscCall(PetscLogGpuTimeEnd());
1112     } break;
1113     case cupmMemcpyDeviceToHost: // not great
1114     case cupmMemcpyHostToHost: { // worst case
1115       const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
1116       PetscScalar *yptr;
1117 
1118       PetscCall(VecGetArrayWrite(yout, &yptr));
1119       if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
1120       PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
1121       if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
1122       PetscCall(VecRestoreArrayWrite(yout, &yptr));
1123     } break;
1124     default:
1125       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
1126     }
1127     PetscCall(PetscDeviceContextSynchronize(dctx));
1128   } else {
1129     PetscCall(MaybeIncrementEmptyLocalVec(yout));
1130   }
1131   PetscFunctionReturn(PETSC_SUCCESS);
1132 }
1133 
1134 // v->ops->swap
1135 template <device::cupm::DeviceType T>
1136 inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
1137 {
1138   PetscFunctionBegin;
1139   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1140   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1141     PetscDeviceContext dctx;
1142     cupmBlasHandle_t   cupmBlasHandle;
1143 
1144     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1145     PetscCall(PetscLogGpuTimeBegin());
1146     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1147     PetscCall(PetscLogGpuTimeEnd());
1148     PetscCall(PetscDeviceContextSynchronize(dctx));
1149   } else {
1150     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1151     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1152   }
1153   PetscFunctionReturn(PETSC_SUCCESS);
1154 }
1155 
1156 // v->ops->axpby
1157 template <device::cupm::DeviceType T>
1158 inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
1159 {
1160   PetscFunctionBegin;
1161   if (alpha == PetscScalar(0.0)) {
1162     PetscCall(scale(yin, beta));
1163   } else if (beta == PetscScalar(1.0)) {
1164     PetscCall(axpy(yin, alpha, xin));
1165   } else if (alpha == PetscScalar(1.0)) {
1166     PetscCall(aypx(yin, beta, xin));
1167   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1168     const auto         betaIsZero = beta == PetscScalar(0.0);
1169     const auto         aptr       = cupmScalarPtrCast(&alpha);
1170     PetscDeviceContext dctx;
1171     cupmBlasHandle_t   cupmBlasHandle;
1172 
1173     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1174     {
1175       const auto xptr = DeviceArrayRead(dctx, xin);
1176 
1177       if (betaIsZero /* beta = 0 */) {
1178         // here we can get away with purely write-only as we memcpy into it first
1179         const auto   yptr = DeviceArrayWrite(dctx, yin);
1180         cupmStream_t stream;
1181 
1182         PetscCall(GetHandlesFrom_(dctx, &stream));
1183         PetscCall(PetscLogGpuTimeBegin());
1184         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1185         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1186       } else {
1187         const auto yptr = DeviceArrayReadWrite(dctx, yin);
1188 
1189         PetscCall(PetscLogGpuTimeBegin());
1190         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1191         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1192       }
1193     }
1194     PetscCall(PetscLogGpuTimeEnd());
1195     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1196     PetscCall(PetscDeviceContextSynchronize(dctx));
1197   } else {
1198     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1199   }
1200   PetscFunctionReturn(PETSC_SUCCESS);
1201 }
1202 
1203 // v->ops->axpbypcz
1204 template <device::cupm::DeviceType T>
1205 inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1206 {
1207   PetscFunctionBegin;
1208   if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
1209   PetscCall(axpy(zin, alpha, xin));
1210   PetscCall(axpy(zin, beta, yin));
1211   PetscFunctionReturn(PETSC_SUCCESS);
1212 }
1213 
1214 // v->ops->norm
1215 template <device::cupm::DeviceType T>
1216 inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
1217 {
1218   PetscDeviceContext dctx;
1219   cupmBlasHandle_t   cupmBlasHandle;
1220 
1221   PetscFunctionBegin;
1222   PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1223   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1224     const auto xptr      = DeviceArrayRead(dctx, xin);
1225     PetscInt   flopCount = 0;
1226 
1227     PetscCall(PetscLogGpuTimeBegin());
1228     switch (type) {
1229     case NORM_1_AND_2:
1230     case NORM_1:
1231       PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1232       flopCount = std::max(n - 1, 0);
1233       if (type == NORM_1) break;
1234       ++z; // fall-through
1235   #if PETSC_CPP_VERSION >= 17
1236       [[fallthrough]];
1237   #endif
1238     case NORM_2:
1239     case NORM_FROBENIUS:
1240       PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1241       flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
1242       break;
1243     case NORM_INFINITY: {
1244       cupmBlasInt_t max_loc = 0;
1245       PetscScalar   xv      = 0.;
1246       cupmStream_t  stream;
1247 
1248       PetscCall(GetHandlesFrom_(dctx, &stream));
1249       PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
1250       PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
1251       *z = PetscAbsScalar(xv);
1252       // REVIEW ME: flopCount = ???
1253     } break;
1254     }
1255     PetscCall(PetscLogGpuTimeEnd());
1256     PetscCall(PetscLogGpuFlops(flopCount));
1257   } else {
1258     z[0]                    = 0.0;
1259     z[type == NORM_1_AND_2] = 0.0;
1260   }
1261   PetscFunctionReturn(PETSC_SUCCESS);
1262 }
1263 
1264 namespace detail
1265 {
1266 
1267 struct dotnorm2_mult {
1268   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1269   {
1270     const auto conjt = PetscConj(t);
1271 
1272     return {s * conjt, t * conjt};
1273   }
1274 };
1275 
1276 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1277 // would do it myself but now I am worried that they do so on purpose...
1278 struct dotnorm2_tuple_plus {
1279   using value_type = thrust::tuple<PetscScalar, PetscScalar>;
1280 
1281   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
1282 };
1283 
1284 } // namespace detail
1285 
1286 // v->ops->dotnorm2
1287 template <device::cupm::DeviceType T>
1288 inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
1289 {
1290   PetscDeviceContext dctx;
1291   cupmStream_t       stream;
1292 
1293   PetscFunctionBegin;
1294   PetscCall(GetHandles_(&dctx, &stream));
1295   {
1296     PetscScalar dpt = 0.0, nmt = 0.0;
1297     const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());
1298 
1299     // clang-format off
1300     PetscCallThrust(
1301       thrust::tie(*dp, *nm) = THRUST_CALL(
1302         thrust::inner_product,
1303         stream,
1304         sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
1305         thrust::make_tuple(dpt, nmt),
1306         detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
1307       );
1308     );
1309     // clang-format on
1310   }
1311   PetscFunctionReturn(PETSC_SUCCESS);
1312 }
1313 
1314 namespace detail
1315 {
1316 
1317 struct conjugate {
1318   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
1319 };
1320 
1321 } // namespace detail
1322 
1323 // v->ops->conjugate
1324 template <device::cupm::DeviceType T>
1325 inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
1326 {
1327   PetscFunctionBegin;
1328   if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
1329   PetscFunctionReturn(PETSC_SUCCESS);
1330 }
1331 
1332 namespace detail
1333 {
1334 
1335 struct real_part {
1336   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }
1337 
1338   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
1339 };
1340 
1341 // deriving from Operator allows us to "store" an instance of the operator in the class but
1342 // also take advantage of empty base class optimization if the operator is stateless
1343 template <typename Operator>
1344 class tuple_compare : Operator {
1345 public:
1346   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1347   using operator_type = Operator;
1348 
1349   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1350   {
1351     if (op_()(y.get<0>(), x.get<0>())) {
1352       // if y is strictly greater/less than x, return y
1353       return y;
1354     } else if (y.get<0>() == x.get<0>()) {
1355       // if equal, prefer lower index
1356       return y.get<1>() < x.get<1>() ? y : x;
1357     }
1358     // otherwise return x
1359     return x;
1360   }
1361 
1362 private:
1363   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1364 };
1365 
1366 } // namespace detail
1367 
1368 template <device::cupm::DeviceType T>
1369 template <typename TupleFuncT, typename UnaryFuncT>
1370 inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
1371 {
1372   PetscFunctionBegin;
1373   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
1374   if (p) *p = -1;
1375   if (const auto n = v->map->n) {
1376     PetscDeviceContext dctx;
1377     cupmStream_t       stream;
1378 
1379     PetscCall(GetHandles_(&dctx, &stream));
1380       // needed to:
1381       // 1. switch between transform_reduce and reduce
1382       // 2. strip the real_part functor from the arguments
1383   #if PetscDefined(USE_COMPLEX)
1384     #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
1385   #else
1386     #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
1387   #endif
1388     {
1389       const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1390 
1391       if (p) {
1392         // clang-format off
1393         const auto zip = thrust::make_zip_iterator(
1394           thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
1395         );
1396         // clang-format on
1397         // need to use preprocessor conditionals since otherwise thrust complains about not being
1398         // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
1399         // builds...
1400         // clang-format off
1401         PetscCallThrust(
1402           thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
1403             stream, zip, zip + n, detail::real_part{},
1404             thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
1405           );
1406         );
1407         // clang-format on
1408       } else {
1409         // clang-format off
1410         PetscCallThrust(
1411           *m = THRUST_MINMAX_REDUCE(
1412             stream, vptr, vptr + n, detail::real_part{},
1413             *m, std::forward<UnaryFuncT>(unary_ftr)
1414           );
1415         );
1416         // clang-format on
1417       }
1418     }
1419   #undef THRUST_MINMAX_REDUCE
1420   }
1421   // REVIEW ME: flops?
1422   PetscFunctionReturn(PETSC_SUCCESS);
1423 }
1424 
1425 // v->ops->max
1426 template <device::cupm::DeviceType T>
1427 inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
1428 {
1429   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1430   using unary_functor = thrust::maximum<PetscReal>;
1431 
1432   PetscFunctionBegin;
1433   *m = PETSC_MIN_REAL;
1434   // use {} constructor syntax otherwise most vexing parse
1435   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1436   PetscFunctionReturn(PETSC_SUCCESS);
1437 }
1438 
1439 // v->ops->min
1440 template <device::cupm::DeviceType T>
1441 inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
1442 {
1443   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1444   using unary_functor = thrust::minimum<PetscReal>;
1445 
1446   PetscFunctionBegin;
1447   *m = PETSC_MAX_REAL;
1448   // use {} constructor syntax otherwise most vexing parse
1449   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1450   PetscFunctionReturn(PETSC_SUCCESS);
1451 }
1452 
1453 // v->ops->sum
1454 template <device::cupm::DeviceType T>
1455 inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
1456 {
1457   PetscFunctionBegin;
1458   if (const auto n = v->map->n) {
1459     PetscDeviceContext dctx;
1460     cupmStream_t       stream;
1461 
1462     PetscCall(GetHandles_(&dctx, &stream));
1463     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1464     // REVIEW ME: why not cupmBlasXasum()?
1465     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1466     // REVIEW ME: must be at least n additions
1467     PetscCall(PetscLogGpuFlops(n));
1468   } else {
1469     *sum = 0.0;
1470   }
1471   PetscFunctionReturn(PETSC_SUCCESS);
1472 }
1473 
1474 namespace detail
1475 {
1476 
1477 template <typename T>
1478 class plus_equals {
1479 public:
1480   using value_type = T;
1481 
1482   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_(std::move(v)) { }
1483 
1484   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &val) const noexcept { return val + v_; }
1485 
1486 private:
1487   value_type v_;
1488 };
1489 
1490 } // namespace detail
1491 
1492 template <device::cupm::DeviceType T>
1493 inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
1494 {
1495   PetscFunctionBegin;
1496   PetscCall(pointwiseunary_(detail::plus_equals<PetscScalar>{shift}, v));
1497   PetscFunctionReturn(PETSC_SUCCESS);
1498 }
1499 
1500 template <device::cupm::DeviceType T>
1501 inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
1502 {
1503   PetscFunctionBegin;
1504   if (const auto n = v->map->n) {
1505     PetscBool          iscurand;
1506     PetscDeviceContext dctx;
1507 
1508     PetscCall(GetHandles_(&dctx));
1509     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
1510     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
1511     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
1512   } else {
1513     PetscCall(MaybeIncrementEmptyLocalVec(v));
1514   }
1515   // REVIEW ME: flops????
1516   // REVIEW ME: Timing???
1517   PetscFunctionReturn(PETSC_SUCCESS);
1518 }
1519 
1520 // v->ops->setpreallocation
1521 template <device::cupm::DeviceType T>
1522 inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
1523 {
1524   PetscDeviceContext dctx;
1525 
1526   PetscFunctionBegin;
1527   PetscCall(GetHandles_(&dctx));
1528   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
1529   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
1530   PetscFunctionReturn(PETSC_SUCCESS);
1531 }
1532 
1533 namespace kernels
1534 {
1535 
1536 template <typename F>
1537 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
1538 {
1539   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
1540     const auto  end = jmap[i + 1];
1541     const auto  idx = xvindex(i);
1542     PetscScalar sum = 0.0;
1543 
1544     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
1545 
1546     if (imode == INSERT_VALUES) {
1547       xv[idx] = sum;
1548     } else {
1549       xv[idx] += sum;
1550     }
1551   });
1552   return;
1553 }
1554 
1555 namespace
1556 {
1557 
1558 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
1559 {
1560   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
1561   return;
1562 }
1563 
1564 } // namespace
1565 
1566   #if PetscDefined(USING_HCC)
1567 namespace do_not_use
1568 {
1569 
1570 // Needed to silence clang warning:
1571 //
1572 // warning: function 'FUNCTION NAME' is not needed and will not be emitted
1573 //
1574 // The warning is silly, since the function *is* used, however the host compiler does not
1575 // appear see this. Likely because the function using it is in a template.
1576 //
1577 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
1578 inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
1579 {
1580   (void)sum_kernel;
1581 }
1582 
1583 inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
1584 {
1585   (void)add_coo_values;
1586 }
1587 
1588 } // namespace do_not_use
1589   #endif
1590 
1591 } // namespace kernels
1592 
1593 // v->ops->setvaluescoo
1594 template <device::cupm::DeviceType T>
1595 inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
1596 {
1597   auto               vv = const_cast<PetscScalar *>(v);
1598   PetscMemType       memtype;
1599   PetscDeviceContext dctx;
1600   cupmStream_t       stream;
1601 
1602   PetscFunctionBegin;
1603   PetscCall(GetHandles_(&dctx, &stream));
1604   PetscCall(PetscGetMemType(v, &memtype));
1605   if (PetscMemTypeHost(memtype)) {
1606     const auto size = VecIMPLCast(x)->coo_n;
1607 
1608     // If user gave v[] in host, we might need to copy it to device if any
1609     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
1610     PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
1611   }
1612 
1613   if (const auto n = x->map->n) {
1614     const auto vcu = VecCUPMCast(x);
1615 
1616     PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
1617   } else {
1618     PetscCall(MaybeIncrementEmptyLocalVec(x));
1619   }
1620 
1621   if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
1622   PetscCall(PetscDeviceContextSynchronize(dctx));
1623   PetscFunctionReturn(PETSC_SUCCESS);
1624 }
1625 
1626 // ==========================================================================================
1627 // VecSeq_CUPM - Implementations
1628 // ==========================================================================================
1629 
1630 namespace
1631 {
1632 
1633 template <typename T>
1634 inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
1635 {
1636   PetscFunctionBegin;
1637   PetscValidPointer(v, 4);
1638   PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
1639   PetscFunctionReturn(PETSC_SUCCESS);
1640 }
1641 
1642 template <typename T>
1643 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
1644 {
1645   PetscFunctionBegin;
1646   if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5);
1647   PetscValidPointer(v, 7);
1648   PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
1649   PetscFunctionReturn(PETSC_SUCCESS);
1650 }
1651 
1652 template <PetscMemoryAccessMode mode, typename T>
1653 inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1654 {
1655   PetscFunctionBegin;
1656   PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
1657   PetscValidPointer(a, 3);
1658   PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1659   PetscFunctionReturn(PETSC_SUCCESS);
1660 }
1661 
1662 template <PetscMemoryAccessMode mode, typename T>
1663 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1664 {
1665   PetscFunctionBegin;
1666   PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
1667   PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1668   PetscFunctionReturn(PETSC_SUCCESS);
1669 }
1670 
1671 template <typename T>
1672 inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1673 {
1674   PetscFunctionBegin;
1675   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1676   PetscFunctionReturn(PETSC_SUCCESS);
1677 }
1678 
1679 template <typename T>
1680 inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1681 {
1682   PetscFunctionBegin;
1683   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1684   PetscFunctionReturn(PETSC_SUCCESS);
1685 }
1686 
1687 template <typename T>
1688 inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1689 {
1690   PetscFunctionBegin;
1691   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1692   PetscFunctionReturn(PETSC_SUCCESS);
1693 }
1694 
1695 template <typename T>
1696 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1697 {
1698   PetscFunctionBegin;
1699   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1700   PetscFunctionReturn(PETSC_SUCCESS);
1701 }
1702 
1703 template <typename T>
1704 inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1705 {
1706   PetscFunctionBegin;
1707   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1708   PetscFunctionReturn(PETSC_SUCCESS);
1709 }
1710 
1711 template <typename T>
1712 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1713 {
1714   PetscFunctionBegin;
1715   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1716   PetscFunctionReturn(PETSC_SUCCESS);
1717 }
1718 
1719 template <typename T>
1720 inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1721 {
1722   PetscFunctionBegin;
1723   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1724   PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1725   PetscFunctionReturn(PETSC_SUCCESS);
1726 }
1727 
1728 template <typename T>
1729 inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1730 {
1731   PetscFunctionBegin;
1732   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1733   PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1734   PetscFunctionReturn(PETSC_SUCCESS);
1735 }
1736 
1737 template <typename T>
1738 inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept
1739 {
1740   PetscFunctionBegin;
1741   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1742   PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin));
1743   PetscFunctionReturn(PETSC_SUCCESS);
1744 }
1745 
1746 } // anonymous namespace
1747 
1748 } // namespace impl
1749 
1750 } // namespace cupm
1751 
1752 } // namespace vec
1753 
1754 } // namespace Petsc
1755 
1756 #endif // __cplusplus
1757 
1758 #endif // PETSCVECSEQCUPM_HPP
1759