xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 35916fb55af248d44baf214f8a1ededac6a8329b)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #define PETSC_SKIP_SPINLOCK // REVIEW ME: why
5 
6 #include <petsc/private/veccupmimpl.h>
7 #include <petsc/private/randomimpl.h> // for _p_PetscRandom
8 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
9 #include "../src/sys/objects/device/impls/cupm/cupmallocator.hpp"
10 #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
11 
12 #if defined(__cplusplus)
13   #include <thrust/transform_reduce.h>
14   #include <thrust/reduce.h>
15   #include <thrust/functional.h>
16   #include <thrust/iterator/counting_iterator.h>
17   #include <thrust/inner_product.h>
18 
19 namespace Petsc
20 {
21 
22 namespace vec
23 {
24 
25 namespace cupm
26 {
27 
28 namespace impl
29 {
30 
31 // ==========================================================================================
32 // VecSeq_CUPM
33 // ==========================================================================================
34 
// Implementation class for a sequential (single-rank) PETSc vector whose data may live in
// CUDA or HIP device memory, selected by the DeviceType template parameter T. Derives
// (privately, CRTP-style) from Vec_CUPMBase, which supplies the shared host/device array
// management declared via PETSC_VEC_CUPM_BASE_CLASS_HEADER.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // cast Vec::data to the host-side Vec_Seq implementation struct
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  // name of the implementation vector type (VECSEQCUPM)
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  // thin dispatchers to the plain host (VecXXX_Seq) implementations
  PETSC_NODISCARD static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state of a locally-empty (but globally nonempty) vec; see the extended
  // comment at the definition for the deadlock scenario this prevents
  PETSC_NODISCARD static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  PETSC_NODISCARD static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  PETSC_NODISCARD static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  PETSC_NODISCARD static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers, split on whether PetscScalar is complex (std::true_type/false_type tag)
  PETSC_NODISCARD static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  PETSC_NODISCARD static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  PETSC_NODISCARD static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  PETSC_NODISCARD static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  PETSC_NODISCARD static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  PETSC_NODISCARD static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  PETSC_NODISCARD static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  PETSC_NODISCARD static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  PETSC_NODISCARD static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers (installed in Vec's ops table)
  PETSC_NODISCARD static PetscErrorCode duplicate(Vec, Vec *) noexcept;
  PETSC_NODISCARD static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode reciprocal(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  PETSC_NODISCARD static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode set(Vec, PetscScalar) noexcept;
  PETSC_NODISCARD static PetscErrorCode scale(Vec, PetscScalar) noexcept;
  PETSC_NODISCARD static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode copy(Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode swap(Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept;
  PETSC_NODISCARD static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode destroy(Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  PETSC_NODISCARD static PetscErrorCode getlocalvector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  PETSC_NODISCARD static PetscErrorCode restorelocalvector(Vec, Vec) noexcept;
  PETSC_NODISCARD static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept;
  PETSC_NODISCARD static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept;
  PETSC_NODISCARD static PetscErrorCode sum(Vec, PetscScalar *) noexcept;
  PETSC_NODISCARD static PetscErrorCode shift(Vec, PetscScalar) noexcept;
  PETSC_NODISCARD static PetscErrorCode setrandom(Vec, PetscRandom) noexcept;
  PETSC_NODISCARD static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept;
  PETSC_NODISCARD static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept;
  PETSC_NODISCARD static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept;
};
113 
114 // ==========================================================================================
115 // VecSeq_CUPM - Private API
116 // ==========================================================================================
117 
118 template <device::cupm::DeviceType T>
119 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
120 {
121   return static_cast<Vec_Seq *>(v->data);
122 }
123 
// Return the type name of this implementation (VECSEQCUPM, i.e. "seqcuda"/"seqhip"
// depending on T).
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
129 
// Dispatch destruction of the host-side data to the plain VecSeq implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}
135 
// Dispatch host-array reset to the plain VecSeq implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}
141 
// Dispatch host-array placement to the plain VecSeq implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}
147 
// Create the host-side VecSeq backing for this vector, optionally adopting a
// user-supplied host array. Errors unless the vector's communicator has exactly one rank
// (sequential vectors only). *alloc_missing is always set to PETSC_FALSE here since
// VecCreate_Seq_Private() handles allocation. The unnamed PetscInt parameter is unused in
// this (sequential) specialization.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(0);
}
160 
// for functions with an early return based on a vec's size we still need to artificially bump the
162 // object state. This is to prevent the following:
163 //
164 // 0. Suppose you have a Vec {
165 //   rank 0: [0],
166 //   rank 1: [<empty>]
167 // }
168 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
169 // 2. Vec enters e.g. VecSet(10)
170 // 3. rank 1 has local size 0 and bails immediately
171 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
172 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
173 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
174 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
175 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
176 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
177 template <device::cupm::DeviceType T>
178 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
179 {
180   PetscFunctionBegin;
181   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
182   PetscFunctionReturn(0);
183 }
184 
// Common creation core: set up the host-side VecSeq data (optionally adopting host_array)
// then initialize the CUPM (device) side, optionally adopting device_array. Null array
// pointers mean "allocate when needed".
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(0);
}
193 
// Apply an elementwise binary functor on the device: zout[i] = binary(xin[i], yin[i]).
// Note the argument roles: xin and yin are read, zout is written.
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<BinaryFuncT>(binary), n, DeviceArrayRead(dctx, xin).data(), DeviceArrayRead(dctx, yin).data(), DeviceArrayWrite(dctx, zout).data()));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // locally empty: still bump the object state to keep parallel ranks consistent
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(0);
}
210 
// Apply an elementwise unary functor on the device. If yin is null (or aliases xinout)
// the operation is done in place on xinout; otherwise yin receives unary(xinout[i]).
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    if (inplace) {
      PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // locally empty: bump the state of whichever vec would have been written
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(0);
}
237 
238 // ==========================================================================================
239 // VecSeq_CUPM - Public API - Constructors
240 // ==========================================================================================
241 
242 // VecCreateSeqCUPM()
// VecCreateSeqCUPM()
// Create a sequential CUPM vector of local (== global) length n with block size bs.
// call_set_type controls whether VecSetType() is invoked by the base creation routine.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(0);
}
250 
251 // VecCreateSeqCUPMWithArrays()
// VecCreateSeqCUPMWithArrays()
// Create a sequential CUPM vector that adopts user-provided host and/or device arrays
// (either may be null). The arrays are not copied; the caller retains ownership.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // createseqcupm_() is called!
  PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE));
  PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(0);
}
265 
266 // v->ops->duplicate
267 template <device::cupm::DeviceType T>
268 inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept
269 {
270   PetscDeviceContext dctx;
271 
272   PetscFunctionBegin;
273   PetscCall(GetHandles_(&dctx));
274   PetscCall(Duplicate_CUPMBase(v, y, dctx));
275   PetscFunctionReturn(0);
276 }
277 
278 // ==========================================================================================
279 // VecSeq_CUPM - Public API - Utility
280 // ==========================================================================================
281 
282 // v->ops->bindtocpu
// v->ops->bindtocpu
// Pin (or unpin) the vector to host execution and swap the ops table accordingly.
// Each VecSetOp_CUPM() entry presumably selects the host implementation (second argument)
// when usehost is true and the device implementation (third argument) otherwise — the
// macro is defined elsewhere; confirm against veccupmimpl.h.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, tdot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, mdot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>);
  // mtdot has only a host implementation, so it is assigned unconditionally
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, max);
  VecSetOp_CUPM(min, VecMin_Seq, min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo);
  PetscFunctionReturn(0);
}
307 
308 // ==========================================================================================
// VecSeq_CUPM - Public API - Mutators
310 // ==========================================================================================
311 
312 // v->ops->getlocalvector or v->ops->getlocalvectorread
// v->ops->getlocalvector or v->ops->getlocalvectorread
// Make w a "local view" of v. If both vectors are native CUPM vectors, w's own storage
// (host and device) is released and w simply adopts v's data pointers; otherwise w's host
// array is pointed at v's array via VecGetArray[Read](). The access template parameter
// selects read-only vs read-write semantics.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // free w's own host allocation (if any) before it is overwritten below
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // use the pinned-host allocator iff the memory was pinned; exchange() also clears
        // the flag since the allocation is going away
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    // likewise free w's own device allocation
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: w aliases v's data outright
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // generic path: borrow v's host array through the public accessor
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(0);
}
367 
368 // v->ops->restorelocalvector or v->ops->restorelocalvectorread
// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Undo getlocalvector(): hand the (possibly modified) data pointers back from w to v and
// sever w's aliases so the two vectors no longer share storage.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    // generic path: return the borrowed host array through the public accessor
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // w allocated its own device array in getlocalvector(); free it now
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(0);
}
403 
404 // ==========================================================================================
405 // VecSeq_CUPM - Public API - Compute Methods
406 // ==========================================================================================
407 
408 // v->ops->aypx
409 template <device::cupm::DeviceType T>
410 inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept
411 {
412   const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
413   const auto         sync = n != 0;
414   PetscDeviceContext dctx;
415 
416   PetscFunctionBegin;
417   PetscCall(GetHandles_(&dctx));
418   if (alpha == PetscScalar(0.0)) {
419     cupmStream_t stream;
420 
421     PetscCall(GetHandlesFrom_(dctx, &stream));
422     PetscCall(PetscLogGpuTimeBegin());
423     PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
424     PetscCall(PetscLogGpuTimeEnd());
425   } else if (n) {
426     const auto       alphaIsOne = alpha == PetscScalar(1.0);
427     const auto       calpha     = cupmScalarPtrCast(&alpha);
428     cupmBlasHandle_t cupmBlasHandle;
429 
430     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
431     {
432       const auto yptr = DeviceArrayReadWrite(dctx, yin);
433       const auto xptr = DeviceArrayRead(dctx, xin);
434 
435       PetscCall(PetscLogGpuTimeBegin());
436       if (alphaIsOne) {
437         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
438       } else {
439         const auto one = cupmScalarCast(1.0);
440 
441         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
442         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
443       }
444       PetscCall(PetscLogGpuTimeEnd());
445     }
446     PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
447   }
448   if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
449   PetscFunctionReturn(0);
450 }
451 
// v->ops->axpy
// Compute y += alpha*x. If x is not a CUPM vector the host VecAXPY_Seq fallback is used
// (y is CUPM since this is installed in a CUPM vec's ops table).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(0);
}
476 
477 // v->ops->pointwisedivide
478 template <device::cupm::DeviceType T>
479 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept
480 {
481   PetscFunctionBegin;
482   if (xin->boundtocpu || yin->boundtocpu) {
483     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
484   } else {
485     // note order of arguments! xin and yin are read, win is written!
486     PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
487   }
488   PetscFunctionReturn(0);
489 }
490 
491 // v->ops->pointwisemult
492 template <device::cupm::DeviceType T>
493 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept
494 {
495   PetscFunctionBegin;
496   if (xin->boundtocpu || yin->boundtocpu) {
497     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
498   } else {
499     // note order of arguments! xin and yin are read, win is written!
500     PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
501   }
502   PetscFunctionReturn(0);
503 }
504 
505 namespace detail
506 {
507 
508 struct reciprocal {
509   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
510   {
511     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
512     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
513     // everything in PetscScalar...
514     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
515   }
516 };
517 
518 } // namespace detail
519 
520 // v->ops->reciprocal
// v->ops->reciprocal
// In-place x[i] = 1/x[i] (zeros are left as zero, see detail::reciprocal).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(0);
}
528 
// v->ops->waxpy
// Compute w = alpha*x + y. alpha == 0 degenerates to w = y (a copy); otherwise y is
// copied into w and alpha*x accumulated with a single axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      // scope the write accessor so its destructor runs before flop logging
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(0);
}
555 
556 namespace kernels
557 {
558 
// Device kernel for x += sum_j a[j]*y[j], with one y pointer per variadic argument.
// The coefficients are staged in shared memory (one element per yptr) before the
// element loop.
template <typename... Args>
PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  // all threads must see the coefficients before using them below
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    // accumulate directly onto the loaded x value, then store once
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}
590 
591 } // namespace kernels
592 
593 namespace detail
594 {
595 
596 // a helper-struct to gobble the size_t input, it is used with template parameter pack
597 // expansion such that
598 // typename repeat_type<MyType, IdxParamPack>...
599 // expands to
600 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
// gobbles the std::size_t parameter and always yields T (see the comment above)
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};
605 
606 } // namespace detail
607 
// Launch maxpy_kernel with sizeof...(Idx) y-vector arguments; the index sequence expands
// both the repeated pointer-type list and the per-vector device array reads.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(0);
}
624 
// Dispatch a batch of N vectors starting at offset yidx, then advance yidx past them.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(0);
}
634 
// v->ops->maxpy
// Compute x += sum_i alpha[i]*y[i] by launching the fused maxpy kernel on batches of up
// to 8 vectors at a time (the switch unrolls the 1..7 remainder cases). The host alpha
// array is staged into device memory for the kernels.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(0);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // each dispatch consumes its batch and advances yidx
      switch (nv - yidx) {
      case 7:
        PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(0);
}
691 
// v->ops->dot
// Compute z = conj(y) . x via BLAS; an empty vector yields z = 0 without touching the
// device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(0);
}
712 
713   #define MDOT_WORKGROUP_NUM  128
714   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
715 
716 namespace kernels
717 {
718 
719 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
720 {
721   const auto group_entries = (size - 1) / gridDim.x + 1;
722   // for very small vectors, a group should still do some work
723   return group_entries ? group_entries : 1;
724 }
725 
// Device kernel computing, for each of the N y vectors, block-wise partial sums of
// y[j] . x. Each block reduces its partial sums in shared memory and writes one value per
// vector to results[bx + j*gridDim.x]; a separate pass must finish the reduction.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  // per-thread partial sums over this block's slice of the vector
  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
774 
775 PETSC_KERNEL_DECL static void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
776 {
777   int         local_i = 0;
778   PetscScalar local_results[8];
779 
780   // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
781   //
782   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
783   // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
784   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
785   //  |  ______________________________________________________/
786   //  | /            <- MDOT_WORKGROUP_NUM ->
787   //  |/
788   //  +
789   //  v
790   // *-*-*
791   // | | | ...
792   // *-*-*
793   //
794   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
795     PetscScalar z_sum = 0;
796 
797     for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
798     local_results[local_i++] = z_sum;
799   });
800   // if we needed more than 1 workgroup to handle the vector we should sync since other threads
801   // may currently be reading from results
802   if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
803   // Local buffer is now written to global memory
804   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
805     const auto j = --local_i;
806 
807     if (j >= 0) results[i] = local_results[j];
808   });
809   return;
810 }
811 
812 } // namespace kernels
813 
// Launch mdot_kernel instantiated for sizeof...(Idx) y-vectors; results must have room
// for sizeof...(Idx) * MDOT_WORKGROUP_NUM partial sums (one row of partials per vector),
// which are later collapsed by kernels::sum_kernel.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      // repeat_type expands Idx into sizeof...(Idx) copies of const PetscScalar *
      kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(0);
}
832 
// Dispatch a batch of N dot products starting at yin[yidx], writing partials starting at
// results + yidx * MDOT_WORKGROUP_NUM, and advance yidx past the batch.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(0);
}
842 
// mdot_() overload selected on real builds (std::false_type <=> !PetscDefined(USE_COMPLEX)):
// computes z[i] = x . y[i] for i in [0, nv) by dispatching batches of up to 8 custom
// kernels, optionally round-robined over forked sub-streams, with a final device-side
// reduction (sum_kernel) of the per-block partials.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // round-robin the batches over the sub-streams (if any were created)
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a lone trailing vector is handed to BLAS directly (its result goes straight into
        // z rather than into d_results, hence nvbatch = nv - 1 above)
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    // join (and destroy) the sub-streams so all batches are visible on dctx's stream
    // before the final reduction below
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(0);
}
932 
933   #undef MDOT_WORKGROUP_NUM
934   #undef MDOT_WORKGROUP_SIZE
935 
// mdot_() overload selected on complex builds (std::true_type <=> PetscDefined(USE_COMPLEX)):
// runs one conjugating BLAS dot per y-vector, round-robined over forked sub-contexts, with
// results written device-side (pointer mode DEVICE) so the loop never synchronizes; a
// single async device-to-host copy gathers all results at the end.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // temporarily switch the handle to device pointer mode so the dot result lands in
    // d_z + i on the device; restore the caller's mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    // y-first: BLAS conjugates the first argument, PETSc the second
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(0);
}
970 
// v->ops->mdot
//
// Computes z[i] = x . y[i] for i in [0, nv); dispatches to the real or complex mdot_()
// overload via tag dispatch on PetscDefined(USE_COMPLEX).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    // NOTE(review): nv == 0 actually takes the branches below (this one requires nv == 1);
    // for n > 0 it trips the PetscCheck, otherwise it hits the PetscArrayzero -- confirm
    // intended behavior for nv == 0
    PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    // tag dispatch: true_type -> complex (BLAS per vector), false_type -> real (batched kernels)
    PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: all dot products are zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(0);
}
993 
994 // v->ops->set
995 template <device::cupm::DeviceType T>
996 inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
997 {
998   const auto         n = xin->map->n;
999   PetscDeviceContext dctx;
1000   cupmStream_t       stream;
1001 
1002   PetscFunctionBegin;
1003   PetscCall(GetHandles_(&dctx, &stream));
1004   {
1005     const auto xptr = DeviceArrayWrite(dctx, xin);
1006 
1007     if (alpha == PetscScalar(0.0)) {
1008       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1009     } else {
1010       PetscCall(device::cupm::impl::ThrustSet<T>(stream, n, xptr.data(), &alpha));
1011     }
1012     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1013   }
1014   PetscFunctionReturn(0);
1015 }
1016 
1017 // v->ops->scale
1018 template <device::cupm::DeviceType T>
1019 inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
1020 {
1021   PetscFunctionBegin;
1022   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(0);
1023   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1024     PetscCall(set(xin, alpha));
1025   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1026     PetscDeviceContext dctx;
1027     cupmBlasHandle_t   cupmBlasHandle;
1028 
1029     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1030     PetscCall(PetscLogGpuTimeBegin());
1031     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1032     PetscCall(PetscLogGpuTimeEnd());
1033     PetscCall(PetscLogGpuFlops(n));
1034     PetscCall(PetscDeviceContextSynchronize(dctx));
1035   } else {
1036     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1037   }
1038   PetscFunctionReturn(0);
1039 }
1040 
1041 // v->ops->tdot
1042 template <device::cupm::DeviceType T>
1043 inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
1044 {
1045   PetscFunctionBegin;
1046   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1047     PetscDeviceContext dctx;
1048     cupmBlasHandle_t   cupmBlasHandle;
1049 
1050     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1051     PetscCall(PetscLogGpuTimeBegin());
1052     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1053     PetscCall(PetscLogGpuTimeEnd());
1054     PetscCall(PetscLogGpuFlops(2 * n - 1));
1055   } else {
1056     *z = 0.0;
1057   }
1058   PetscFunctionReturn(0);
1059 }
1060 
// v->ops->copy
//
// Copy xin into yout, choosing the memcpy direction from the two vectors' offload masks
// so that no unnecessary host<->device transfer is triggered.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(0);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        // y is an unallocated cupm vector, so allocate it on whichever side x currently lives
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
    case PETSC_OFFLOAD_CPU:
      // destination is host memory; source side follows x's mask
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      // destination is device memory; source side follows x's mask
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // destination is y's host buffer, obtained through the generic accessor
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(0);
}
1127 
1128 // v->ops->swap
1129 template <device::cupm::DeviceType T>
1130 inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
1131 {
1132   PetscFunctionBegin;
1133   if (xin == yin) PetscFunctionReturn(0);
1134   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1135     PetscDeviceContext dctx;
1136     cupmBlasHandle_t   cupmBlasHandle;
1137 
1138     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1139     PetscCall(PetscLogGpuTimeBegin());
1140     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1141     PetscCall(PetscLogGpuTimeEnd());
1142     PetscCall(PetscDeviceContextSynchronize(dctx));
1143   } else {
1144     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1145     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1146   }
1147   PetscFunctionReturn(0);
1148 }
1149 
// v->ops->axpby
//
// y = alpha*x + beta*y. Special scalar values are peeled off first (alpha == 0 -> scale,
// beta == 1 -> axpy, alpha == 1 -> aypx); the general case is built from BLAS scal + axpy
// (or memcpy + scal when beta == 0).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    PetscCall(axpy(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    PetscCall(aypx(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        // y = alpha*x realized as y <- x followed by y *= alpha
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        // y <- beta*y, then y += alpha*x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    // note the timer was started inside whichever branch ran above
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(0);
}
1196 
// v->ops->axpbypcz
//
// z = alpha*x + beta*y + gamma*z, composed from scale() plus two axpy()s. The scale is
// skipped when gamma == 1 since scaling by one is a no-op.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
  PetscCall(axpy(zin, alpha, xin));
  PetscCall(axpy(zin, beta, yin));
  PetscFunctionReturn(0);
}
1207 
// v->ops->norm
//
// Compute the requested norm(s) of xin via BLAS. For NORM_1_AND_2 the NORM_1 case writes
// z[0] and then deliberately falls through (after ++z) so the 2-norm lands in z[1].
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      // NOTE: for complex scalars asum sums |Re| + |Im| per entry, which is the
      // convention this backend uses for the 1-norm
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // amax returns a 1-based index, hence the "- 1" when dereferencing below
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: zero the one (or, for NORM_1_AND_2, both) output entries
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(0);
}
1254 
namespace detail
{

// Elementwise kernel for dotnorm2(): maps (s_i, t_i) to the pair (s_i*conj(t_i),
// t_i*conj(t_i)), so the reduced sums yield the dot product and |t|^2 respectively.
struct dotnorm2_mult {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};

// it is positively __bananas__ that thrust does not define default operator+ for tuples... I
// would do it myself but now I am worried that they do so on purpose...
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  // componentwise sum of two (dot, norm) partial results
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};

} // namespace detail
1276 
// v->ops->dotnorm2
//
// Compute, in a single fused pass over the data, dp = sum_i s_i*conj(t_i) and
// nm = sum_i t_i*conj(t_i) using thrust::inner_product with the functors above.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // zero-valued initial accumulators for the reduction
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(0);
}
1304 
namespace detail
{

// unary functor applying complex conjugation to a single entry
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail
1313 
// v->ops->conjugate
//
// Conjugate each entry of xin in place; a no-op on real builds since conjugation is the
// identity there.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
  PetscFunctionReturn(0);
}
1322 
namespace detail
{

// Projections used by minmax_(): extract the real part of a scalar, either alone or from
// a (value, index) zip tuple. Only needed on complex builds (see THRUST_MINMAX_REDUCE).
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};

// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  // Reduce two (value, index) candidates: pick the one preferred by Operator, and on a
  // value tie prefer the lower index.
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ? y : x;
    }
    // otherwise return x
    return x;
  }

private:
  // view the base subobject as the comparison operator
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};

} // namespace detail
1358 
// Common core for min() and max(): reduce v to the extremal value *m (and, when p is
// non-null, its index *p) using thrust. The caller seeds *m with the reduction identity
// (PETSC_MIN_REAL or PETSC_MAX_REAL) and supplies matching tuple/unary comparison functors.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  // -1 signals "no entries" when the vector is empty
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // index requested: reduce over (value, index) pairs built from a zip iterator
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value-only reduction
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(0);
}
1415 
1416 // v->ops->max
1417 template <device::cupm::DeviceType T>
1418 inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
1419 {
1420   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1421   using unary_functor = thrust::maximum<PetscReal>;
1422 
1423   PetscFunctionBegin;
1424   *m = PETSC_MIN_REAL;
1425   // use {} constructor syntax otherwise most vexing parse
1426   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1427   PetscFunctionReturn(0);
1428 }
1429 
1430 // v->ops->min
1431 template <device::cupm::DeviceType T>
1432 inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
1433 {
1434   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1435   using unary_functor = thrust::minimum<PetscReal>;
1436 
1437   PetscFunctionBegin;
1438   *m = PETSC_MAX_REAL;
1439   // use {} constructor syntax otherwise most vexing parse
1440   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1441   PetscFunctionReturn(0);
1442 }
1443 
// v->ops->sum
//
// Sum all entries of v via a thrust reduction.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // NOTE(review): dptr is taken from a temporary DeviceArrayRead whose lifetime ends at
    // the semicolon; this matches the pattern used elsewhere in this file (dotnorm2,
    // minmax_), which presumably relies on the proxy's destructor leaving the device
    // pointer valid -- TODO confirm
    const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
    // REVIEW ME: why not cupmBlasXasum()?
    PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
    // REVIEW ME: must be at least n additions
    PetscCall(PetscLogGpuFlops(n));
  } else {
    *sum = 0.0;
  }
  PetscFunctionReturn(0);
}
1464 
1465 namespace detail
1466 {
1467 
1468 template <typename T>
1469 class plus_equals {
1470 public:
1471   using value_type = T;
1472 
1473   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_(std::move(v)) { }
1474 
1475   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &val) const noexcept { return val + v_; }
1476 
1477 private:
1478   value_type v_;
1479 };
1480 
1481 } // namespace detail
1482 
// v->ops->shift
//
// Add the scalar shift to every entry of v.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(detail::plus_equals<PetscScalar>{shift}, v));
  PetscFunctionReturn(0);
}
1490 
// v->ops->setrandom
//
// Fills the local entries of v with values drawn from the generator rand. A CURAND-type
// generator writes straight into device memory (DeviceArrayWrite); any other generator fills
// the host array instead (HostArrayWrite).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    // zero local length: nothing to fill, but the vector object state may still need bumping
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(0);
}
1510 
// v->ops->setpreallocationcoo
//
// Preallocates for ncoo COO-style insertions at indices coo_i[]: first builds the host-side
// COO data (VecSetPreallocationCOO_Seq), then sets up the CUPM (device) counterpart
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(0);
}
1523 
namespace kernels
{

// Device-side core of setvaluescoo(): for each of the n destination entries i, gathers the
// (possibly repeated) COO contributions vv[perm[k]] for k in [jmap[i], jmap[i+1]) into a
// single sum, then either overwrites (INSERT_VALUES) or accumulates into xv[xvindex(i)].
// xvindex lets callers remap a local entry index to its position in xv.
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    // accumulate every source value mapped to destination entry i
    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}

// kernel entry point with the identity index map: entry i of xv receives the i-th summed value
PETSC_KERNEL_DECL static void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace kernels
1553 
// v->ops->setvaluescoo
//
// Inserts (INSERT_VALUES) or adds (ADD_VALUES) the COO values v[] into x according to the
// mapping set up by setpreallocationcoo(). v[] may live in host or device memory; host input
// is staged through a temporary device buffer.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // v[] is in host memory: copy it into a temporary device buffer so the kernel below can
    // read it; the buffer is freed at the end of this function
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // scatter-add the COO values; for INSERT_VALUES the destination array is acquired with
    // DeviceArrayWrite (prior device contents discarded), otherwise with DeviceArrayReadWrite
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    // zero local length: nothing to scatter, but the vector object state may still need bumping
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  // NOTE(review): presumably synchronized so the caller's v[] is safe to reuse or free on
  // return — confirm against the VecSetValuesCOO contract
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(0);
}
1586 
1587 // ==========================================================================================
1588 // VecSeq_CUPM - Implementations
1589 // ==========================================================================================
1590 
1591 namespace
1592 {
1593 
1594 template <typename T>
1595 PETSC_NODISCARD inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
1596 {
1597   PetscFunctionBegin;
1598   PetscValidPointer(v, 4);
1599   PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
1600   PetscFunctionReturn(0);
1601 }
1602 
1603 template <typename T>
1604 PETSC_NODISCARD inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
1605 {
1606   PetscFunctionBegin;
1607   if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5);
1608   PetscValidPointer(v, 7);
1609   PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
1610   PetscFunctionReturn(0);
1611 }
1612 
1613 template <PetscMemoryAccessMode mode, typename T>
1614 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1615 {
1616   PetscFunctionBegin;
1617   PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
1618   PetscValidPointer(a, 3);
1619   PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1620   PetscFunctionReturn(0);
1621 }
1622 
1623 template <PetscMemoryAccessMode mode, typename T>
1624 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1625 {
1626   PetscFunctionBegin;
1627   PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
1628   PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1629   PetscFunctionReturn(0);
1630 }
1631 
1632 template <typename T>
1633 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1634 {
1635   PetscFunctionBegin;
1636   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1637   PetscFunctionReturn(0);
1638 }
1639 
1640 template <typename T>
1641 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1642 {
1643   PetscFunctionBegin;
1644   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1645   PetscFunctionReturn(0);
1646 }
1647 
1648 template <typename T>
1649 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1650 {
1651   PetscFunctionBegin;
1652   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1653   PetscFunctionReturn(0);
1654 }
1655 
1656 template <typename T>
1657 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1658 {
1659   PetscFunctionBegin;
1660   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1661   PetscFunctionReturn(0);
1662 }
1663 
1664 template <typename T>
1665 PETSC_NODISCARD inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1666 {
1667   PetscFunctionBegin;
1668   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1669   PetscFunctionReturn(0);
1670 }
1671 
1672 template <typename T>
1673 PETSC_NODISCARD inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1674 {
1675   PetscFunctionBegin;
1676   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1677   PetscFunctionReturn(0);
1678 }
1679 
1680 template <typename T>
1681 PETSC_NODISCARD inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1682 {
1683   PetscFunctionBegin;
1684   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1685   PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1686   PetscFunctionReturn(0);
1687 }
1688 
1689 template <typename T>
1690 PETSC_NODISCARD inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1691 {
1692   PetscFunctionBegin;
1693   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1694   PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1695   PetscFunctionReturn(0);
1696 }
1697 
1698 template <typename T>
1699 PETSC_NODISCARD inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept
1700 {
1701   PetscFunctionBegin;
1702   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1703   PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin));
1704   PetscFunctionReturn(0);
1705 }
1706 
1707 } // anonymous namespace
1708 
1709 } // namespace impl
1710 
1711 } // namespace cupm
1712 
1713 } // namespace vec
1714 
1715 } // namespace Petsc
1716 
1717 #endif // __cplusplus
1718 
1719 #endif // PETSCVECSEQCUPM_HPP
1720