xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 90585354b1e2059ecf1e5cb64ef3b58c6db60aab)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 
6 #if defined(__cplusplus)
7   #include <petsc/private/randomimpl.h> // for _p_PetscRandom
8 
9   #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence
10 
11   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
12   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
13 
14   #if PetscDefined(USE_COMPLEX)
15     #include <thrust/transform_reduce.h>
16   #endif
17   #include <thrust/transform.h>
18   #include <thrust/reduce.h>
19   #include <thrust/functional.h>
20   #include <thrust/tuple.h>
21   #include <thrust/device_ptr.h>
22   #include <thrust/iterator/zip_iterator.h>
23   #include <thrust/iterator/counting_iterator.h>
24   #include <thrust/inner_product.h>
25 
26 namespace Petsc
27 {
28 
29 namespace vec
30 {
31 
32 namespace cupm
33 {
34 
35 namespace impl
36 {
37 
38 // ==========================================================================================
39 // VecSeq_CUPM
40 // ==========================================================================================
41 
// Sequential (single-rank) CUPM vector implementation. Inherits the shared machinery from
// Vec_CUPMBase via CRTP; the private VecXXX_IMPL_ members adapt the host Vec_Seq
// implementation for use by the base class.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // cast the Vec's data member to the host (Vec_Seq) representation
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  // the VecType string identifying this implementation
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  // trampolines to the host Vec_Seq implementations, consumed by the base class
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state of a vector whose local portion is empty; see the comment above
  // the definition for the deadlock scenario this prevents
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode reciprocal(Vec) noexcept;
  static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode set(Vec, PetscScalar) noexcept;
  static PetscErrorCode scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode copy(Vec, Vec) noexcept;
  static PetscErrorCode swap(Vec, Vec) noexcept;
  static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode destroy(Vec) noexcept;
  static PetscErrorCode conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode getlocalvector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode restorelocalvector(Vec, Vec) noexcept;
  static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode setrandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept;
  static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept;
};
120 
121 // ==========================================================================================
122 // VecSeq_CUPM - Private API
123 // ==========================================================================================
124 
125 template <device::cupm::DeviceType T>
126 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
127 {
128   return static_cast<Vec_Seq *>(v->data);
129 }
130 
// the VecType string naming this implementation (sequential CUPM vector)
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
136 
// forward destruction to the host Vec_Seq implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}
142 
// forward array reset to the host Vec_Seq implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}
148 
// forward array placement to the host Vec_Seq implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}
154 
// create the host-side Vec_Seq backing for v, optionally using a caller-supplied host array.
// The unnamed PetscInt parameter is unused here (it exists to match the base-class
// interface); alloc_missing (if given) is always reported PETSC_FALSE.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  // sequential vectors only make sense on single-rank communicators
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
167 
168 // for functions with an early return based on vec size we still need to artificially bump the
169 // object state. This is to prevent the following:
170 //
171 // 0. Suppose you have a Vec {
172 //   rank 0: [0],
173 //   rank 1: [<empty>]
174 // }
175 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
176 // 2. Vec enters e.g. VecSet(10)
177 // 3. rank 1 has local size 0 and bails immediately
178 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
179 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
180 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
181 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
182 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
183 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
184 template <device::cupm::DeviceType T>
185 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
186 {
187   PetscFunctionBegin;
188   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
189   PetscFunctionReturn(PETSC_SUCCESS);
190 }
191 
// common core of the create routines: build the host-side representation then initialize
// the CUPM (device) side, optionally adopting caller-supplied host/device arrays
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
200 
// common core for pointwise binary operations: zout[i] = binary(xin[i], yin[i]) on the
// device via thrust::transform. Handles the empty-local-vector state bump itself.
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local portion: still advance the object state, see MaybeIncrementEmptyLocalVec()
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
232 
// common core for pointwise unary operations: applies unary() elementwise to xinout,
// writing into yin when given (and distinct), otherwise in place
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;
    // note: the lambda parameters deliberately shadow the outer xinout/yin; here they are
    // raw device pointers rather than Vecs
    const auto         apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local portion: bump the state of whichever vector would have been written
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
277 
278 // ==========================================================================================
279 // VecSeq_CUPM - Public API - Constructors
280 // ==========================================================================================
281 
282 // VecCreateSeqCUPM()
// backing implementation for VecCreateSeqCUPM(); local and global sizes coincide (both n)
// for a sequential vector
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
290 
291 // VecCreateSeqCUPMWithArrays()
// backing implementation for VecCreateSeqCUPMWithArrays(); adopts caller-supplied host
// and/or device arrays instead of allocating
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // createseqcupm_() is called!
  PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE));
  PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
305 
306 // v->ops->duplicate
// v->ops->duplicate: create a new vector of the same layout/type as v (values not copied)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
317 
318 // ==========================================================================================
319 // VecSeq_CUPM - Public API - Utility
320 // ==========================================================================================
321 
322 // v->ops->bindtocpu
// v->ops->bindtocpu: pin the vector to host (usehost == PETSC_TRUE) or re-enable the
// device path; VecSetOp_CUPM selects between the Vec_Seq and CUPM function pointers
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, tdot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, mdot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>);
  // mtdot has no CUPM implementation, the host version is used in both cases
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, max);
  VecSetOp_CUPM(min, VecMin_Seq, min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo);
  PetscFunctionReturn(PETSC_SUCCESS);
}
347 
348 // ==========================================================================================
349 // VecSeq_CUPM - Public API - Mutators
350 // ==========================================================================================
351 
352 // v->ops->getlocalvector or v->ops->getlocalvectorread
// v->ops->getlocalvector or v->ops->getlocalvectorread: make w an alias of v's local
// portion. When both are native seq-CUPM vectors w steals v's pointers wholesale;
// otherwise w borrows v's host array via VecGetArray[Read]()
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release any storage w currently owns, it is about to alias v's
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // temporarily restore the allocator that matches how the array was allocated
        // (pinned vs regular host memory) so PetscFree() releases it correctly
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: w takes over v's data pointers directly
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // slow path: borrow v's host array; access selects read-only vs read-write borrowing
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
407 
408 // v->ops->restorelocalvector or v->ops->restorelocalvectorread
// v->ops->restorelocalvector or v->ops->restorelocalvectorread: undo getlocalvector(),
// handing ownership of the data back from w to v
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    // return the borrowed host array; access must match the one used in getlocalvector()
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      // free w's device-side storage, it is stale now that the host array went back to v
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
443 
444 // ==========================================================================================
445 // VecSeq_CUPM - Public API - Compute Methods
446 // ==========================================================================================
447 
448 // v->ops->aypx
// v->ops->aypx: y = alpha*y + x
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    // alpha == 0 reduces the operation to a plain device-to-device copy of x into y
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // alpha == 1: a single axpy (y += x) suffices
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        // general case: scale y by alpha, then add x
        const auto one = cupmScalarCast(1.0);

        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
491 
492 // v->ops->axpy
// v->ops->axpy: y += alpha*x; falls back to the host implementation when x is not a
// CUPM vector
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
516 
517 // v->ops->pointwisedivide
518 template <device::cupm::DeviceType T>
519 inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept
520 {
521   PetscFunctionBegin;
522   if (xin->boundtocpu || yin->boundtocpu) {
523     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
524   } else {
525     // note order of arguments! xin and yin are read, win is written!
526     PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
527   }
528   PetscFunctionReturn(PETSC_SUCCESS);
529 }
530 
531 // v->ops->pointwisemult
532 template <device::cupm::DeviceType T>
533 inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept
534 {
535   PetscFunctionBegin;
536   if (xin->boundtocpu || yin->boundtocpu) {
537     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
538   } else {
539     // note order of arguments! xin and yin are read, win is written!
540     PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
541   }
542   PetscFunctionReturn(PETSC_SUCCESS);
543 }
544 
545 namespace detail
546 {
547 
548 struct reciprocal {
549   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
550   {
551     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
552     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
553     // everything in PetscScalar...
554     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
555   }
556 };
557 
558 } // namespace detail
559 
560 // v->ops->reciprocal
// v->ops->reciprocal: x[i] = 1/x[i] in place (zero entries stay zero, see detail::reciprocal)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
568 
569 // v->ops->waxpy
// v->ops->waxpy: w = alpha*x + y
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // alpha == 0 reduces the operation to w = y
    PetscCall(copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      // copy y into w, then w += alpha*x
      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
595 
596 namespace kernels
597 {
598 
// device kernel computing x[i] += sum_j aptr[j]*yptr[j][i] for the N = sizeof...(Args)
// y-vectors passed as variadic pointer arguments
template <typename... Args>
PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
  #if 0
    // variant 1: accumulate into a zero-initialized sum, then add to x[i]
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    // variant 2 (used): seed the sum with x[i] and store back
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}
630 
631 } // namespace kernels
632 
633 namespace detail
634 {
635 
636 // a helper-struct to gobble the size_t input, it is used with template parameter pack
637 // expansion such that
638 // typename repeat_type<MyType, IdxParamPack>...
639 // expands to
640 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T; // the std::size_t parameter is intentionally unnamed, it exists only for pack expansion
};
645 
646 } // namespace detail
647 
// launch maxpy_kernel for sizeof...(Idx) y-vectors; the index sequence is expanded both
// to instantiate the kernel's variadic pointer parameters and to pull out each yin[Idx]
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
664 
// dispatch a batch of N y-vectors starting at offset yidx, then advance yidx by N
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
674 
675 // v->ops->maxpy
// v->ops->maxpy: x += sum_i alpha[i]*y[i], processing the y-vectors in batches of up
// to 8 per kernel launch
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    // stage the alpha coefficients on the device so the kernels can read them
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    // dispatch the remaining vectors in the largest batch size available (max 8)
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
731 
// v->ops->dot: z = conj(x) . y (complex-conjugating convention handled below)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    // empty vector: the dot product is zero by definition
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
752 
753   #define MDOT_WORKGROUP_NUM  128
754   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
755 
756 namespace kernels
757 {
758 
// number of vector entries each block (group) should process in mdot_kernel
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}
765 
// device kernel computing per-block partial dot products of x against each of the
// N = sizeof...(y) vectors; block bx writes its partial sum for vector i to
// results[bx + i * gridDim.x] (a later reduction combines the blocks)
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  // each thread accumulates its slice of this block's [begin, end) range
  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
814 
815 namespace
816 {
817 
// Final reduction stage for mdot_: collapse each vector's MDOT_WORKGROUP_SIZE per-block
// partial sums (laid out contiguously in `results`) into a single scalar, compacting the
// output in place so that results[0..size) holds the final dot products
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];
  // NOTE(review): local_results holds at most 8 sums, so each thread is assumed to own no
  // more than 8 iterations of the grid-stride loop -- confirm against the launch
  // configuration used in mdot_()

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}
854 
855 } // namespace
856 
857 } // namespace kernels
858 
// Launch one mdot_kernel instantiation computing the dot products of xarr against
// sizeof...(Idx) of the yin vectors; per-block partial sums are written to `results`.
// repeat_type expands the index pack into one `const PetscScalar *` template argument per
// vector, matching the kernel's variadic parameter list
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
877 
// Batch-of-N wrapper around the index_sequence dispatcher: process vectors
// yin[yidx..yidx+N) and advance yidx past them. Each batch writes into its own
// yidx * MDOT_WORKGROUP_NUM offset slice of `results`
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
887 
// Real-scalar (std::false_type) implementation of VecMDot: compute z[i] = y[i]^T x for all
// nv vectors. Vectors are processed in batches of up to 8 via custom kernels, round-robined
// over forked sub-streams for concurrency; a final sum_kernel launch reduces the per-block
// partial sums, and the results are copied back to the host buffer z
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    // consume the y vectors batch-by-batch, rotating through the sub-contexts so batches can
    // run concurrently on different streams
    do {
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a single leftover vector is cheaper to hand to the BLAS dot routine, which writes
        // its result straight into z (this is why nvbatch excludes the singleton above)
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // collapse the per-block partial sums into the first nvbatch entries of d_results
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}
977 
978   #undef MDOT_WORKGROUP_NUM
979   #undef MDOT_WORKGROUP_SIZE
980 
// Complex-scalar (std::true_type) implementation of VecMDot: issue one BLAS dot per vector,
// fanned out over up to 8 forked sub-contexts. The BLAS handle is temporarily switched to
// device pointer mode so each call writes its result directly into device buffer d_z
// without an intermediate host round-trip; the results are copied back in one memcpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // save and restore the pointer mode so the handle is left exactly as we found it
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}
1015 
// v->ops->mdot: z[i] = y[i]^H x for i in [0, nv). Dispatches to the complex or real
// implementation of mdot_() at compile time
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    // USE_COMPLEX selects the BLAS-per-vector path (true) or the batched-kernel path (false)
    PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector, every dot product is zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1038 
1039 // v->ops->set
1040 template <device::cupm::DeviceType T>
1041 inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
1042 {
1043   const auto         n = xin->map->n;
1044   PetscDeviceContext dctx;
1045   cupmStream_t       stream;
1046 
1047   PetscFunctionBegin;
1048   PetscCall(GetHandles_(&dctx, &stream));
1049   {
1050     const auto xptr = DeviceArrayWrite(dctx, xin);
1051 
1052     if (alpha == PetscScalar(0.0)) {
1053       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1054     } else {
1055       const auto dptr = thrust::device_pointer_cast(xptr.data());
1056 
1057       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1058     }
1059     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1060   }
1061   PetscFunctionReturn(PETSC_SUCCESS);
1062 }
1063 
1064 // v->ops->scale
1065 template <device::cupm::DeviceType T>
1066 inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
1067 {
1068   PetscFunctionBegin;
1069   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1070   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1071     PetscCall(set(xin, alpha));
1072   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1073     PetscDeviceContext dctx;
1074     cupmBlasHandle_t   cupmBlasHandle;
1075 
1076     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1077     PetscCall(PetscLogGpuTimeBegin());
1078     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1079     PetscCall(PetscLogGpuTimeEnd());
1080     PetscCall(PetscLogGpuFlops(n));
1081     PetscCall(PetscDeviceContextSynchronize(dctx));
1082   } else {
1083     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1084   }
1085   PetscFunctionReturn(PETSC_SUCCESS);
1086 }
1087 
1088 // v->ops->tdot
1089 template <device::cupm::DeviceType T>
1090 inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
1091 {
1092   PetscFunctionBegin;
1093   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1094     PetscDeviceContext dctx;
1095     cupmBlasHandle_t   cupmBlasHandle;
1096 
1097     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1098     PetscCall(PetscLogGpuTimeBegin());
1099     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1100     PetscCall(PetscLogGpuTimeEnd());
1101     PetscCall(PetscLogGpuFlops(2 * n - 1));
1102   } else {
1103     *z = 0.0;
1104   }
1105   PetscFunctionReturn(PETSC_SUCCESS);
1106 }
1107 
// v->ops->copy: y <- x. The source may live on host or device (per xin->offloadmask) and
// the destination may be unallocated, host-, or device-resident, so the first switch maps
// the (xmask, ymask) pair to the cupmMemcpyKind that moves the data with the fewest
// transfers, and the second switch performs it
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      // an unallocated CUPM vector can be written wherever x currently lives
      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // host destination goes through the generic VecGetArrayWrite so yout's offload state
      // is updated correctly
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1177 
1178 // v->ops->swap
1179 template <device::cupm::DeviceType T>
1180 inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
1181 {
1182   PetscFunctionBegin;
1183   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1184   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1185     PetscDeviceContext dctx;
1186     cupmBlasHandle_t   cupmBlasHandle;
1187 
1188     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1189     PetscCall(PetscLogGpuTimeBegin());
1190     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1191     PetscCall(PetscLogGpuTimeEnd());
1192     PetscCall(PetscDeviceContextSynchronize(dctx));
1193   } else {
1194     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1195     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1196   }
1197   PetscFunctionReturn(PETSC_SUCCESS);
1198 }
1199 
// v->ops->axpby: y <- alpha*x + beta*y. Trivial alpha/beta values are forwarded to the
// cheaper specialized operations; the general case is a BLAS scal followed by an axpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // y <- beta*y
    PetscCall(scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    // y <- alpha*x + y
    PetscCall(axpy(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    // y <- beta*y + x
    PetscCall(aypx(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // y <- alpha*x: copy x into y then scale y in place
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        // general case: y <- beta*y, then y <- alpha*x + y
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    // the GPU timer was started inside whichever branch ran above
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1246 
// v->ops->axpbypcz: z <- alpha*x + beta*y + gamma*z, implemented as a scale of z followed
// by two axpys. The ordering matters: z must be scaled before anything is accumulated in
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  // gamma == 1 needs no scaling (scale() would also short-circuit, but skip the call)
  if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
  PetscCall(axpy(zin, alpha, xin));
  PetscCall(axpy(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1257 
// v->ops->norm: compute the requested norm(s) of x. For NORM_1_AND_2 both the 1-norm and
// 2-norm are produced, in z[0] and z[1] respectively
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      // NORM_1_AND_2: advance z so the fall-through 2-norm computation lands in z[1]
      ++z; // fall-through
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // amax returns a 1-based index, hence the -1 when dereferencing
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: all norms are zero (the second store zeroes z[1] iff both were requested)
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1307 
1308 namespace detail
1309 {
1310 
1311 struct dotnorm2_mult {
1312   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1313   {
1314     const auto conjt = PetscConj(t);
1315 
1316     return {s * conjt, t * conjt};
1317   }
1318 };
1319 
1320 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1321 // would do it myself but now I am worried that they do so on purpose...
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  // componentwise sum of two (dot, norm2) partial results
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};
1327 
1328 } // namespace detail
1329 
// v->ops->dotnorm2: compute dp = sum_i s_i * conj(t_i) and nm = sum_i t_i * conj(t_i) in a
// single fused pass over both vectors via thrust::inner_product over scalar pairs
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // zero-initialized accumulators forming the initial (dot, norm2) tuple
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1357 
1358 namespace detail
1359 {
1360 
// elementwise complex-conjugation functor (the identity for real scalars)
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};
1364 
1365 } // namespace detail
1366 
// v->ops->conjugate: x <- conj(x), elementwise
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  // conjugation is the identity for real scalars, so only do work in complex builds
  if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1375 
1376 namespace detail
1377 {
1378 
struct real_part {
  // tuple overload used by minmax_() when the location is requested: strip the scalar down
  // to its real part while carrying the index along
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  // scalar overload for the value-only reduction
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};
1384 
1385 // deriving from Operator allows us to "store" an instance of the operator in the class but
1386 // also take advantage of empty base class optimization if the operator is stateless
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  // reduce two (value, index) pairs to one, using Operator (e.g. thrust::greater/less) to
  // order the values and breaking ties in favor of the smaller index
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(y.get<0>(), x.get<0>())) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (y.get<0>() == x.get<0>()) {
      // if equal, prefer lower index
      return y.get<1>() < x.get<1>() ? y : x;
    }
    // otherwise return x
    return x;
  }

private:
  // expose the (possibly stateless, empty-base-optimized) comparison operator
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};
1409 
1410 } // namespace detail
1411 
// Common implementation of VecMax/VecMin: a thrust reduction over (value, index) tuples
// when the location p is requested, or over bare values otherwise. The caller must seed *m
// with the reduction identity (PETSC_MIN_REAL for max, PETSC_MAX_REAL for min)
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  // -1 signals "no location" (left in place when the vector is empty)
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair every entry with its index so the reduction can track the location
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value-only reduction
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}
1468 
1469 // v->ops->max
1470 template <device::cupm::DeviceType T>
1471 inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
1472 {
1473   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1474   using unary_functor = thrust::maximum<PetscReal>;
1475 
1476   PetscFunctionBegin;
1477   *m = PETSC_MIN_REAL;
1478   // use {} constructor syntax otherwise most vexing parse
1479   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1480   PetscFunctionReturn(PETSC_SUCCESS);
1481 }
1482 
1483 // v->ops->min
1484 template <device::cupm::DeviceType T>
1485 inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
1486 {
1487   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1488   using unary_functor = thrust::minimum<PetscReal>;
1489 
1490   PetscFunctionBegin;
1491   *m = PETSC_MAX_REAL;
1492   // use {} constructor syntax otherwise most vexing parse
1493   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1494   PetscFunctionReturn(PETSC_SUCCESS);
1495 }
1496 
1497 // v->ops->sum
1498 template <device::cupm::DeviceType T>
1499 inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
1500 {
1501   PetscFunctionBegin;
1502   if (const auto n = v->map->n) {
1503     PetscDeviceContext dctx;
1504     cupmStream_t       stream;
1505 
1506     PetscCall(GetHandles_(&dctx, &stream));
1507     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1508     // REVIEW ME: why not cupmBlasXasum()?
1509     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1510     // REVIEW ME: must be at least n additions
1511     PetscCall(PetscLogGpuFlops(n));
1512   } else {
1513     *sum = 0.0;
1514   }
1515   PetscFunctionReturn(PETSC_SUCCESS);
1516 }
1517 
// v->ops->shift? -- adds the scalar `shift` to every entry of v via a single device-side
// elementwise transform
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(device::cupm::functors::make_plus_equals(shift), v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1525 
1526 template <device::cupm::DeviceType T>
1527 inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
1528 {
1529   PetscFunctionBegin;
1530   if (const auto n = v->map->n) {
1531     PetscBool          iscurand;
1532     PetscDeviceContext dctx;
1533 
1534     PetscCall(GetHandles_(&dctx));
1535     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
1536     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
1537     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
1538   } else {
1539     PetscCall(MaybeIncrementEmptyLocalVec(v));
1540   }
1541   // REVIEW ME: flops????
1542   // REVIEW ME: Timing???
1543   PetscFunctionReturn(PETSC_SUCCESS);
1544 }
1545 
1546 // v->ops->setpreallocation
1547 template <device::cupm::DeviceType T>
1548 inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
1549 {
1550   PetscDeviceContext dctx;
1551 
1552   PetscFunctionBegin;
1553   PetscCall(GetHandles_(&dctx));
1554   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
1555   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
1556   PetscFunctionReturn(PETSC_SUCCESS);
1557 }
1558 
1559 namespace kernels
1560 {
1561 
// Device-side core shared by the COO assembly kernels.
//
//   vv      - values as handed to the COO set-values path (may contain repeated entries
//             targeting the same destination)
//   n       - number of destinations to process
//   jmap    - prefix map: contributions for destination i occupy slots
//             [jmap[i], jmap[i+1]) of perm
//   perm    - maps those slots back into vv
//   imode   - INSERT_VALUES overwrites the destination entry; any other mode adds to it
//   xv      - the vector's data array being assembled into
//   xvindex - functor mapping loop index i to the index of xv to update
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    // accumulate every (possibly repeated) contribution belonging to destination idx
    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}
1580 
1581 namespace
1582 {
1583 
// Kernel entry point for VecSetValuesCOO on this backend: identity index map, i.e.
// destination i of the jmap1/perm1 tables updates xv[i] directly
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}
1589 
1590 } // namespace
1591 
  #if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
//
// Each function below odr-uses one kernel so the host pass considers it "needed".
// NOTE(review): sum_kernel is declared elsewhere in this file (not in this excerpt).
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
  #endif
1616 
1617 } // namespace kernels
1618 
// v->ops->setvaluescoo
//
// Scatter-add the COO values v[] (host or device memory) into x using the jmap1/perm1
// tables built by setpreallocationcoo(). INSERT_VALUES overwrites entries, otherwise
// values are added.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  // vv aliases v unless we must stage a device copy below
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES fully overwrites x, so a write-only access suffices; otherwise we
    // must read the existing entries too
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // release the staging buffer (the free is stream-ordered via dctx)
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  // NOTE(review): synchronize presumably so the caller may reuse/free v[] immediately on
  // return -- confirm before relaxing
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1651 
1652 // ==========================================================================================
1653 // VecSeq_CUPM - Implementations
1654 // ==========================================================================================
1655 
1656 namespace
1657 {
1658 
1659 template <typename T>
1660 inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
1661 {
1662   PetscFunctionBegin;
1663   PetscValidPointer(v, 4);
1664   PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
1665   PetscFunctionReturn(PETSC_SUCCESS);
1666 }
1667 
1668 template <typename T>
1669 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
1670 {
1671   PetscFunctionBegin;
1672   if (n && cpuarray) PetscValidScalarPointer(cpuarray, 5);
1673   PetscValidPointer(v, 7);
1674   PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
1675   PetscFunctionReturn(PETSC_SUCCESS);
1676 }
1677 
1678 template <PetscMemoryAccessMode mode, typename T>
1679 inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1680 {
1681   PetscFunctionBegin;
1682   PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
1683   PetscValidPointer(a, 3);
1684   PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1685   PetscFunctionReturn(PETSC_SUCCESS);
1686 }
1687 
1688 template <PetscMemoryAccessMode mode, typename T>
1689 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1690 {
1691   PetscFunctionBegin;
1692   PetscValidHeaderSpecific(v, VEC_CLASSID, 2);
1693   PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1694   PetscFunctionReturn(PETSC_SUCCESS);
1695 }
1696 
1697 template <typename T>
1698 inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1699 {
1700   PetscFunctionBegin;
1701   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1702   PetscFunctionReturn(PETSC_SUCCESS);
1703 }
1704 
1705 template <typename T>
1706 inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1707 {
1708   PetscFunctionBegin;
1709   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1710   PetscFunctionReturn(PETSC_SUCCESS);
1711 }
1712 
1713 template <typename T>
1714 inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1715 {
1716   PetscFunctionBegin;
1717   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1718   PetscFunctionReturn(PETSC_SUCCESS);
1719 }
1720 
1721 template <typename T>
1722 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1723 {
1724   PetscFunctionBegin;
1725   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1726   PetscFunctionReturn(PETSC_SUCCESS);
1727 }
1728 
1729 template <typename T>
1730 inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1731 {
1732   PetscFunctionBegin;
1733   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1734   PetscFunctionReturn(PETSC_SUCCESS);
1735 }
1736 
1737 template <typename T>
1738 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1739 {
1740   PetscFunctionBegin;
1741   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1742   PetscFunctionReturn(PETSC_SUCCESS);
1743 }
1744 
1745 template <typename T>
1746 inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1747 {
1748   PetscFunctionBegin;
1749   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1750   PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1751   PetscFunctionReturn(PETSC_SUCCESS);
1752 }
1753 
1754 template <typename T>
1755 inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1756 {
1757   PetscFunctionBegin;
1758   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1759   PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1760   PetscFunctionReturn(PETSC_SUCCESS);
1761 }
1762 
1763 template <typename T>
1764 inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept
1765 {
1766   PetscFunctionBegin;
1767   PetscValidHeaderSpecific(vin, VEC_CLASSID, 2);
1768   PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin));
1769   PetscFunctionReturn(PETSC_SUCCESS);
1770 }
1771 
1772 } // anonymous namespace
1773 
1774 } // namespace impl
1775 
1776 } // namespace cupm
1777 
1778 } // namespace vec
1779 
1780 } // namespace Petsc
1781 
1782 #endif // __cplusplus
1783 
1784 #endif // PETSCVECSEQCUPM_HPP
1785