xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision eec179cf895b6fbcd6a0b58694319392b06e361b)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #define PETSC_SKIP_SPINLOCK // REVIEW ME: why
5 
6 #include <petsc/private/veccupmimpl.h>
7 #include <petsc/private/randomimpl.h> // for _p_PetscRandom
8 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
9 #include "../src/sys/objects/device/impls/cupm/cupmallocator.hpp"
10 #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
11 
12 #if defined(__cplusplus)
13   #include <thrust/transform_reduce.h>
14   #include <thrust/reduce.h>
15   #include <thrust/functional.h>
16   #include <thrust/iterator/counting_iterator.h>
17   #include <thrust/inner_product.h>
18 
19 namespace Petsc
20 {
21 
22 namespace vec
23 {
24 
25 namespace cupm
26 {
27 
28 namespace impl
29 {
30 
31 // ==========================================================================================
32 // VecSeq_CUPM
33 // ==========================================================================================
34 
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // cast Vec::data to the host-side sequential implementation struct
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  // the VecType name of this (sequential) CUPM vector implementation
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  // trampolines to the plain host (VecSeq) implementations, used by the CRTP base class
  PETSC_NODISCARD static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state of a locally-empty (but globally non-empty) vector so that
  // collective state checks stay consistent across ranks
  PETSC_NODISCARD static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  PETSC_NODISCARD static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  PETSC_NODISCARD static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  PETSC_NODISCARD static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  PETSC_NODISCARD static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  PETSC_NODISCARD static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  PETSC_NODISCARD static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  PETSC_NODISCARD static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  // maxpy kernel dispatchers, mirroring the mdot pair above
  template <std::size_t... Idx>
  PETSC_NODISCARD static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  PETSC_NODISCARD static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  PETSC_NODISCARD static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  PETSC_NODISCARD static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  PETSC_NODISCARD static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers (installed in Vec's ops table)
  PETSC_NODISCARD static PetscErrorCode duplicate(Vec, Vec *) noexcept;
  PETSC_NODISCARD static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode reciprocal(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  PETSC_NODISCARD static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode set(Vec, PetscScalar) noexcept;
  PETSC_NODISCARD static PetscErrorCode scale(Vec, PetscScalar) noexcept;
  PETSC_NODISCARD static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode copy(Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode swap(Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept;
  PETSC_NODISCARD static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode destroy(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  PETSC_NODISCARD static PetscErrorCode getlocalvector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  PETSC_NODISCARD static PetscErrorCode restorelocalvector(Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept;
  PETSC_NODISCARD static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept;
  PETSC_NODISCARD static PetscErrorCode sum(Vec, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode shift(Vec, PetscScalar) noexcept;
  PETSC_NODISCARD static PetscErrorCode setrandom(Vec, PetscRandom) noexcept;
  PETSC_NODISCARD static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept;
  PETSC_NODISCARD static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept;
  PETSC_NODISCARD static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept;
};
113 
114 // ==========================================================================================
115 // VecSeq_CUPM - Private API
116 // ==========================================================================================
117 
// returns the host-side sequential implementation struct stashed in Vec::data
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}
123 
// the VecType string for this implementation (VECSEQCUPM() presumably resolves to the
// backend-specific name, e.g. CUDA vs HIP, depending on T)
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
129 
// trampoline to the plain host implementation, used by the CRTP base class
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}
135 
// trampoline to the plain host implementation, used by the CRTP base class
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}
141 
// trampoline to the plain host implementation, used by the CRTP base class
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}
147 
// create the host-side Vec_Seq state for v, optionally adopting host_array as its storage.
// The unnamed PetscInt parameter is unused here; it presumably exists to match the
// base-class dispatch signature shared with the MPI implementation -- TODO confirm.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  // the sequential implementation always performs its allocation, so nothing is "missing"
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(0);
}
160 
 161 // for functions with an early return based on vec size we still need to artificially bump the
162 // object state. This is to prevent the following:
163 //
164 // 0. Suppose you have a Vec {
165 //   rank 0: [0],
166 //   rank 1: [<empty>]
167 // }
168 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
169 // 2. Vec enters e.g. VecSet(10)
170 // 3. rank 1 has local size 0 and bails immediately
171 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
172 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
173 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
174 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
175 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
176 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  // only relevant when this rank's chunk is empty but the global vector is not; see the
  // step-by-step deadlock scenario described in the comment above
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(0);
}
184 
// common core for the create routines: build the host-side Vec_Seq state, then the
// device-side state. Either array pointer may be null, in which case storage is handled
// by the respective initialization path.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(0);
}
193 
// common core for pointwise binary ops: zout[i] = binary(xin[i], yin[i]) over the local
// length of zout, executed on the device via thrust
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<BinaryFuncT>(binary), n, DeviceArrayRead(dctx, xin).data(), DeviceArrayRead(dctx, yin).data(), DeviceArrayWrite(dctx, zout).data()));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // locally empty: still bump the object state so collective state checks stay in sync
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(0);
}
210 
// common core for pointwise unary ops. With yin == nullptr (or yin == xinout) the
// transform is applied in place to xinout; otherwise xinout is read and yin written.
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    if (inplace) {
      PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // locally empty: bump the state of whichever vector would have been written
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(0);
}
237 
238 // ==========================================================================================
239 // VecSeq_CUPM - Public API - Constructors
240 // ==========================================================================================
241 
242 // VecCreateSeqCUPM()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  // sequential vector: local size equals global size, hence (n, n)
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(0);
}
250 
251 // VecCreateSeqCUPMWithArrays()
// create a sequential CUPM vector that uses the caller-provided host and device arrays.
// NOTE(review): constness is cast away below, which suggests the vector does not take
// ownership of either array -- confirm against the public VecCreateSeqCUPMWithArrays docs.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // createseqcupm_() is called!
  PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE));
  PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(0);
}
265 
266 // v->ops->duplicate
// v->ops->duplicate: create a new vector with the same layout as v (defers to the base)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(0);
}
277 
278 // ==========================================================================================
279 // VecSeq_CUPM - Public API - Utility
280 // ==========================================================================================
281 
282 // v->ops->bindtocpu
// v->ops->bindtocpu: switch the vector's op table between the host (VecSeq) and device
// implementations; the base class handles the common ops, the rest are patched here
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  // (each VecSetOp_CUPM presumably selects the host or device function based on usehost)
  VecSetOp_CUPM(dot, VecDot_Seq, dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, tdot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, mdot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device implementation; always use the host version
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, max);
  VecSetOp_CUPM(min, VecMin_Seq, min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo);
  PetscFunctionReturn(0);
}
307 
308 // ==========================================================================================
309 // VecSeq_CUPM - Public API - Mutatators
310 // ==========================================================================================
311 
312 // v->ops->getlocalvector or v->ops->getlocalvectorread
// v->ops->getlocalvector or v->ops->getlocalvectorread: make w expose v's local data.
// When both are native CUPM vectors w's own storage is freed and w aliases v's data
// outright; otherwise w borrows v's host array via VecGetArray[Read]().
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release any storage w currently owns, since it is about to take over v's
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // install the pinned-host deallocator for the duration of the free when w's
        // memory was pinned; w->pinned_memory is reset to PETSC_FALSE in the process
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      // free w's device array (stream-ordered), then the device-side struct itself
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: w aliases v's data pointers directly (undone in restorelocalvector())
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // general path: expose v's host array through w's Vec_Seq array pointer
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(0);
}
367 
368 // v->ops->restorelocalvector or v->ops->restorelocalvectorread
// v->ops->restorelocalvector or v->ops->restorelocalvectorread: undo getlocalvector(),
// handing the (possibly updated) data back to v and detaching it from w
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    // general path: return the borrowed host array pointer to v
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // free the device-side storage that DeviceAllocateCheck_() may have created for w
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(0);
}
403 
404 // ==========================================================================================
405 // VecSeq_CUPM - Public API - Compute Methods
406 // ==========================================================================================
407 
408 // v->ops->aypx
// v->ops->aypx: y = alpha*y + x
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0; // skip the final sync when the vector is empty
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    // y = 0*y + x reduces to a plain device-to-device copy of x into y
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      // scope the device-array handles so they are released before the flop logging below
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // alpha == 1: a single axpy (y += x) suffices
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        // general case: y = alpha*y, then y += 1*x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(0);
}
451 
452 // v->ops->axpy
// v->ops->axpy: y += alpha*x. Falls back to the host implementation when x is not a
// CUPM vector (x may be an arbitrary vector type here).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(0);
}
476 
477 // v->ops->pointwisedivide
478 template <device::cupm::DeviceType T>
479 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept
480 {
481   PetscFunctionBegin;
482   if (xin->boundtocpu || yin->boundtocpu) {
483     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
484   } else {
485     // note order of arguments! xin and yin are read, win is written!
486     PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
487   }
488   PetscFunctionReturn(0);
489 }
490 
491 // v->ops->pointwisemult
492 template <device::cupm::DeviceType T>
493 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept
494 {
495   PetscFunctionBegin;
496   if (xin->boundtocpu || yin->boundtocpu) {
497     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
498   } else {
499     // note order of arguments! xin and yin are read, win is written!
500     PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
501   }
502   PetscFunctionReturn(0);
503 }
504 
505 namespace detail
506 {
507 
// unary functor: maps s -> 1/s, leaving zero entries as zero
struct reciprocal {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
  {
    // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
    // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
    // everything in PetscScalar...
    return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
  }
};
517 
518 } // namespace detail
519 
520 // v->ops->reciprocal
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  // in-place unary transform; zero entries are left untouched (see detail::reciprocal)
  PetscCall(pointwiseunary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(0);
}
528 
529 // v->ops->waxpy
// v->ops->waxpy: w = alpha*x + y
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // alpha == 0 reduces to w = y
    PetscCall(copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      // copy y into w, then w += alpha*x
      // NOTE(review): the trailing 'true' flag differs from the other memcpy calls in
      // this file -- confirm its meaning against PetscCUPMMemcpyAsync
      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(0);
}
555 
556 namespace kernels
557 {
558 
// device kernel for maxpy: x[i] += sum_j aptr[j]*yptr[j][i], with one y pointer per
// variadic argument. The alpha coefficients are staged in shared memory first.
template <typename... Args>
PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
    // (the branches accumulate in a different order, so floating-point rounding differs)
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}
590 
591 } // namespace kernels
592 
593 namespace detail
594 {
595 
596 // a helper-struct to gobble the size_t input, it is used with template parameter pack
597 // expansion such that
598 // typename repeat_type<MyType, IdxParamPack>...
599 // expands to
600 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T; // the std::size_t parameter exists only to be expanded over
};
605 
606 } // namespace detail
607 
// launch maxpy_kernel with exactly sizeof...(Idx) y-vector pointers; repeat_type expands
// the index pack into that many `const PetscScalar *` template arguments
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(0);
}
624 
// dispatch one batch of N vectors starting at offset yidx, then advance yidx by N
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(0);
}
634 
635 // v->ops->maxpy
// v->ops->maxpy: x += sum_i alpha[i]*y[i], processed in kernel launches of up to 8
// vectors at a time
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(0);
    // stage the alpha coefficients on the device so the kernel can read them
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    // peel off the y vectors in batches; each dispatch advances yidx by its template
    // parameter, taking batches of 8 until fewer than 8 remain
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(0);
}
691 
// v->ops->dot: z = conj(x) . y (PETSc convention conjugates the second BLAS argument's
// counterpart, see note below)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    // n multiplications + (n - 1) additions
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    // empty vector: the dot product is zero by definition
    *z = 0.0;
  }
  PetscFunctionReturn(0);
}
712 
713   #define MDOT_WORKGROUP_NUM  128
714   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
715 
716 namespace kernels
717 {
718 
// number of vector entries assigned to each block: ceil(size / gridDim.x).
// NOTE(review): for size == 0 the expression (size - 1) / gridDim.x mixes a negative
// signed value with unsigned gridDim.x -- confirm callers never launch with size 0
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}
725 
// device kernel for mdot: computes partial sums of y[j] . x for each of the N y vectors.
// Each block reduces its slice in shared memory and writes one partial result per
// (block, vector) pair to results[bx + j * gridDim.x]; a separate kernel sums these.
// NOTE: no conjugation is performed here -- presumably the complex case is handled by a
// different code path (see the std::true_type mdot_ overload) -- confirm.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stash each thread's N partial sums in shared memory for the block-level reduction
  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
774 
// Final reduction stage for mdot: collapse the MDOT_WORKGROUP_NUM per-block partial sums of
// each of the `size` dot products (laid out contiguously in `results`) down to a single
// value per dot product, compacting them into results[0..size).
PETSC_KERNEL_DECL static void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  // NOTE(review): 8 matches the maximum mdot batch size, so each thread appears to handle at
  // most 8 grid-stride iterations with the launch configuration used by mdot_ -- confirm that
  // no launch can make a single thread exceed this bound
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory; walk the buffer backwards (mirroring the
  // forward fill above) so each thread writes back exactly the entries it summed
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}
811 
812 } // namespace kernels
813 
// Launch mdot_kernel with one y-vector device pointer per index in Idx..., i.e. a batch of
// sizeof...(Idx) simultaneous dot products against xarr. Partial sums land in `results`.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(0);
}
832 
// Dispatch a batch of N dot products starting at yin[yidx], writing partial sums at an
// offset of yidx * MDOT_WORKGROUP_NUM into `results`, then advance yidx past the batch.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(0);
}
842 
// Real-scalar implementation of multi-dot: batches the y vectors in groups of up to 8,
// launching each batch on a (possibly forked) sub-stream, then performs a final on-device
// reduction and copies the nv results back to the host array z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;       // index of the next unprocessed y vector
    auto                subidx     = 0;       // round-robin cursor over the sub-streams
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // dispatch the largest batch (<= 8) that fits in the remaining vectors; each
      // mdot_kernel_dispatch_<N> advances yidx by N
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a lone trailing vector is handled directly by cupmBLAS, writing straight into z
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // collapse the per-workgroup partial sums into the first nvbatch entries of d_results
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  // NOTE(review): this logs host flops (PetscLogFlops) rather than PetscLogGpuFlops -- verify
  // that this is intentional
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(0);
}
932 
933   #undef MDOT_WORKGROUP_NUM
934   #undef MDOT_WORKGROUP_SIZE
935 
// Complex-scalar implementation of multi-dot: issue one cupmBlasXdot per y vector on a
// round-robin of forked sub-contexts (with the BLAS handle temporarily set to device
// pointer mode so results accumulate on the device), then copy all nv results back to z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;   // device-side result array, one entry per y vector
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // switch the handle to device pointer mode (if not already) so the dot result can be
    // written directly to d_z + i on the device, then restore the previous mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    // y first: the BLAS conjugates its first argument while PETSc conjugates the second
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(0);
}
970 
// v->ops->mdot: z[i] = conj(yin[i])^T xin for i in [0, nv)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // single-vector case, delegate directly to dot()
    PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    // nv <= 0 (with a nonempty local vector) is an error, not a no-op
    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    // select the real or complex implementation at compile time
    PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: every dot product is zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(0);
}
993 
994 // v->ops->set
995 template <device::cupm::DeviceType T>
996 inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
997 {
998   const auto         n = xin->map->n;
999   PetscDeviceContext dctx;
1000   cupmStream_t       stream;
1001 
1002   PetscFunctionBegin;
1003   PetscCall(GetHandles_(&dctx, &stream));
1004   {
1005     const auto xptr = DeviceArrayWrite(dctx, xin);
1006 
1007     if (alpha == PetscScalar(0.0)) {
1008       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1009     } else {
1010       PetscCall(device::cupm::impl::ThrustSet<T>(stream, n, xptr.data(), &alpha));
1011     }
1012     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1013   }
1014   PetscFunctionReturn(0);
1015 }
1016 
// v->ops->scale: xin *= alpha
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  // scaling by 1 is a no-op
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(0);
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    // scaling by 0 is a fill, which set() handles via memset
    PetscCall(set(xin, alpha));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    // one multiplication per entry
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: still bump the object state as if it were modified
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(0);
}
1040 
1041 // v->ops->tdot
1042 template <device::cupm::DeviceType T>
1043 inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
1044 {
1045   PetscFunctionBegin;
1046   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1047     PetscDeviceContext dctx;
1048     cupmBlasHandle_t   cupmBlasHandle;
1049 
1050     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1051     PetscCall(PetscLogGpuTimeBegin());
1052     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1053     PetscCall(PetscLogGpuTimeEnd());
1054     PetscCall(PetscLogGpuFlops(2 * n - 1));
1055   } else {
1056     *z = 0.0;
1057   }
1058   PetscFunctionReturn(0);
1059 }
1060 
// v->ops->copy: yout = xin, choosing the memcpy direction from both vectors' offload masks
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(0);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        // destination is a cupm vector with no allocation yet: mirror the source's location
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      // destination lives on the device, so we can write through the device array handle
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // destination lives on the host, so go through the generic VecGetArrayWrite() path
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(0);
}
1130 
// v->ops->swap: exchange the contents of xin and yin in place via cupmBLAS
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  // swapping a vector with itself is a no-op
  if (xin == yin) PetscFunctionReturn(0);
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vectors: both objects still count as modified
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(0);
}
1152 
// v->ops->axpby: y = alpha*x + beta*y, special-casing degenerate coefficients
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // y = beta*y
    PetscCall(scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    // y = alpha*x + y
    PetscCall(axpy(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    // y = x + beta*y
    PetscCall(aypx(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // y = alpha*x: copy x into y then scale in place
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        // general case: y = beta*y, then y += alpha*x
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    // 1 flop/entry for the beta = 0 path (scale only), 3 for scale + axpy
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(0);
}
1199 
1200 // v->ops->axpbypcz
1201 template <device::cupm::DeviceType T>
1202 inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1203 {
1204   PetscFunctionBegin;
1205   if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
1206   PetscCall(axpy(zin, alpha, xin));
1207   PetscCall(axpy(zin, beta, yin));
1208   PetscFunctionReturn(0);
1209 }
1210 
// v->ops->norm: compute the requested vector norm(s). For NORM_1_AND_2 the 1-norm is
// written to z[0] and the 2-norm to z[1].
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      // NORM_1_AND_2: advance z so the 2-norm below lands in z[1], then fall through
      ++z; // fall-through
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // amax returns a 1-based index; fetch that single entry and take its absolute value
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty local vector: zero out z[0], and z[1] as well when two norms were requested
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(0);
}
1260 
1261 namespace detail
1262 {
1263 
// elementwise functor for dotnorm2: given (s_i, t_i) produce the pair
// (s_i * conj(t_i), t_i * conj(t_i)), i.e. the per-entry contributions to the
// dot product and the squared norm of t
struct dotnorm2_mult {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};
1272 
1273 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1274 // would do it myself but now I am worried that they do so on purpose...
// componentwise addition of two (dot, norm2) tuples, used as the reduction operator in
// dotnorm2's inner_product
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};
1280 
1281 } // namespace detail
1282 
// v->ops->dotnorm2: in a single pass compute dp = conj(t)^T s and nm = conj(t)^T t using a
// fused thrust::inner_product over tuples
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // initial values for the (dot, norm2) accumulator; an empty vector yields (0, 0)
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(0);
}
1310 
1311 namespace detail
1312 {
1313 
// elementwise complex conjugation functor, used by VecSeq_CUPM::conjugate
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};
1317 
1318 } // namespace detail
1319 
1320 // v->ops->conjugate
1321 template <device::cupm::DeviceType T>
1322 inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
1323 {
1324   PetscFunctionBegin;
1325   if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
1326   PetscFunctionReturn(0);
1327 }
1328 
1329 namespace detail
1330 {
1331 
// extract the real part of a scalar (or of the scalar component of a (scalar, index)
// tuple); used to strip imaginary parts before min/max comparisons on complex builds
struct real_part {
  // tuple overload: (scalar, index) -> (Re(scalar), index)
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  // plain scalar overload
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};
1337 
1338 // deriving from Operator allows us to "store" an instance of the operator in the class but
1339 // also take advantage of empty base class optimization if the operator is stateless
1340 template <typename Operator>
1341 class tuple_compare : Operator {
1342 public:
1343   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1344   using operator_type = Operator;
1345 
1346   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1347   {
1348     if (op_()(y.get<0>(), x.get<0>())) {
1349       // if y is strictly greater/less than x, return y
1350       return y;
1351     } else if (y.get<0>() == x.get<0>()) {
1352       // if equal, prefer lower index
1353       return y.get<1>() < x.get<1>() ? y : x;
1354     }
1355     // otherwise return x
1356     return x;
1357   }
1358 
1359 private:
1360   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1361 };
1362 
1363 } // namespace detail
1364 
// Common implementation for VecMax()/VecMin(): reduce the vector with tuple_ftr when an
// index is requested (p != nullptr), otherwise with unary_ftr over values alone. The
// caller pre-seeds *m with the reduction identity (PETSC_MIN_REAL / PETSC_MAX_REAL);
// *p is set to -1 when no index can be produced (e.g. an empty local vector).
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // index requested: reduce over (value, index) tuples built by zipping the data with
        // a counting iterator
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value-only reduction
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(0);
}
1421 
1422 // v->ops->max
1423 template <device::cupm::DeviceType T>
1424 inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
1425 {
1426   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1427   using unary_functor = thrust::maximum<PetscReal>;
1428 
1429   PetscFunctionBegin;
1430   *m = PETSC_MIN_REAL;
1431   // use {} constructor syntax otherwise most vexing parse
1432   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1433   PetscFunctionReturn(0);
1434 }
1435 
1436 // v->ops->min
1437 template <device::cupm::DeviceType T>
1438 inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
1439 {
1440   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1441   using unary_functor = thrust::minimum<PetscReal>;
1442 
1443   PetscFunctionBegin;
1444   *m = PETSC_MAX_REAL;
1445   // use {} constructor syntax otherwise most vexing parse
1446   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1447   PetscFunctionReturn(0);
1448 }
1449 
// v->ops->sum: *sum = sum of all entries of v, via a thrust reduction on the device
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
    // REVIEW ME: why not cupmBlasXasum()?
    PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
    // REVIEW ME: must be at least n additions
    PetscCall(PetscLogGpuFlops(n));
  } else {
    // empty local vector: the sum is zero
    *sum = 0.0;
  }
  PetscFunctionReturn(0);
}
1470 
1471 namespace detail
1472 {
1473 
1474 template <typename T>
1475 class plus_equals {
1476 public:
1477   using value_type = T;
1478 
1479   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_(std::move(v)) { }
1480 
1481   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &val) const noexcept { return val + v_; }
1482 
1483 private:
1484   value_type v_;
1485 };
1486 
1487 } // namespace detail
1488 
1489 template <device::cupm::DeviceType T>
1490 inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
1491 {
1492   PetscFunctionBegin;
1493   PetscCall(pointwiseunary_(detail::plus_equals<PetscScalar>{shift}, v));
1494   PetscFunctionReturn(0);
1495 }
1496 
// v->ops->setrandom: fill v with values from rand; a CURAND generator can write directly
// to the device array, any other generator fills the host array instead
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(0);
}
1516 
// v->ops->setpreallocationcoo
//
// Preallocates for ncoo COO-style entries with indices coo_i[]. The host-side COO data
// structures are built first (VecSetPreallocationCOO_Seq); SetPreallocationCOO_CUPMBase then
// sets up the device-side state -- presumably the jmap1_d/perm1_d arrays consumed by
// setvaluescoo(); confirm against the base class implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(0);
}
1529 
namespace kernels
{

// Device-side body shared by the COO assembly kernels.
//
// For each slot i in [0, n) (distributed across threads by grid_stride_1D): the
// contributions belonging to slot i are vv[perm[k]] for k in [jmap[i], jmap[i + 1]); their
// sum either overwrites (INSERT_VALUES) or accumulates into xv[xvindex(i)]. xvindex maps a
// COO slot to its position in the vector array (identity for the sequential kernel below).
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1]; // one-past-last perm entry for slot i
    const auto  idx = xvindex(i);  // destination index in xv
    PetscScalar sum = 0.0;

    // gather all (possibly repeated) contributions for this slot
    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}

// Kernel entry point for the sequential vector: destination index equals the slot index
PETSC_KERNEL_DECL static void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace kernels
1559 
// v->ops->setvaluescoo
//
// Assembles the vector from the COO values v[], using the maps built by
// setpreallocationcoo(). v[] may reside in host or device memory; host input is staged
// through a temporary device buffer for the kernel launch. imode selects overwrite
// (INSERT_VALUES) vs accumulate semantics.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    // coo_n entries were registered at preallocation time
    const auto size = VecIMPLCast(x)->coo_n;

    // v[] lives on the host: stage it through a temporary device buffer (vv) so the kernel
    // can read it; the buffer is freed below once the launch has been issued
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry, so write-only device access suffices; ADD needs
    // read-write access to accumulate into the existing values
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // free the staging buffer (presumably stream-ordered -- confirm PetscDeviceFree semantics)
  // and synchronize so the caller may safely reuse or free v[] on return
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(0);
}
1592 
1593 // ==========================================================================================
1594 // VecSeq_CUPM - Implementations
1595 // ==========================================================================================
1596 
namespace
{

// The helpers in this anonymous namespace back the public C API (vector creation and
// get/restore/place/replace/reset of the device array). Each one takes the VecSeq_CUPM
// implementation object by forwarding reference, validates its arguments with the PETSc
// checking macros (whose order determines which error fires first), and dispatches to the
// corresponding member function.

// Create a sequential device vector of local length n; block size is fixed at 0.
// NOTE(review): the trailing PETSC_TRUE flag's meaning is not visible here -- confirm
// against createseqcupm()
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscValidPointer(v, 4);
  PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
  PetscFunctionReturn(0);
}

// Create a sequential device vector wrapping user-provided host and device arrays; the
// host array is only validated when the vector is non-empty and the pointer is non-null
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5);
  PetscValidPointer(v, 7);
  PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
  PetscFunctionReturn(0);
}

// Common core of the typed get-array helpers below: fetch the device array with the given
// memory-access mode
template <PetscMemoryAccessMode mode, typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
  PetscValidPointer(a, 3);
  PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
  PetscFunctionReturn(0);
}

// Common core of the typed restore-array helpers below
template <PetscMemoryAccessMode mode, typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
  PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
  PetscFunctionReturn(0);
}

// Get the device array for reading and writing
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(0);
}

// Restore an array obtained with VecCUPMGetArrayAsync()
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(0);
}

// Get the device array for read-only access; const-ness is restored for the caller via the
// pointer cast
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
  PetscFunctionReturn(0);
}

// Restore an array obtained with VecCUPMGetArrayReadAsync()
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
  PetscFunctionReturn(0);
}

// Get the device array for write-only access (previous contents need not be valid)
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(0);
}

// Restore an array obtained with VecCUPMGetArrayWriteAsync()
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
  PetscFunctionReturn(0);
}

// Temporarily substitute the vector's device array with a[]; undone by
// VecCUPMResetArrayAsync()
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
  PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(0);
}

// Permanently replace the vector's device array with a[]
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
  PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(0);
}

// Undo a previous VecCUPMPlaceArrayAsync()
template <typename T>
PETSC_NODISCARD inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
  PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin));
  PetscFunctionReturn(0);
}

} // anonymous namespace
1714 
1715 } // namespace impl
1716 
1717 } // namespace cupm
1718 
1719 } // namespace vec
1720 
1721 } // namespace Petsc
1722 
1723 #endif // __cplusplus
1724 
1725 #endif // PETSCVECSEQCUPM_HPP
1726