xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 9fd2a872fac71d98bdeeea4f800da4f052f7a9fb)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 
6 #if defined(__cplusplus)
7   #include <petsc/private/randomimpl.h> // for _p_PetscRandom
8 
9   #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence
10 
11   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
12   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
13 
14   #if PetscDefined(USE_COMPLEX)
15     #include <thrust/transform_reduce.h>
16   #endif
17   #include <thrust/transform.h>
18   #include <thrust/reduce.h>
19   #include <thrust/functional.h>
20   #include <thrust/tuple.h>
21   #include <thrust/device_ptr.h>
22   #include <thrust/iterator/zip_iterator.h>
23   #include <thrust/iterator/counting_iterator.h>
24   #include <thrust/inner_product.h>
25 
26 namespace Petsc
27 {
28 
29 namespace vec
30 {
31 
32 namespace cupm
33 {
34 
35 namespace impl
36 {
37 
38 // ==========================================================================================
39 // VecSeq_CUPM
40 // ==========================================================================================
41 
// Sequential (single-rank) PETSc vector backed by CUDA/HIP (CUPM) device memory. The CRTP
// base Vec_CUPMBase provides the shared host/device array management; this class supplies
// the VecSeq-specific hooks and the compute routines installed into v->ops.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // hooks required by the base class to bridge to the host-side Vec_Seq implementation
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump PetscObjectState on ranks whose local slice is empty (but global size is not)
  // so all ranks' states stay in lock-step; see the comment above its definition
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers; complex vs real is selected at compile time via the tag argument
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers (installed into v->ops by BindToCPU())
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};
119 
120 // ==========================================================================================
121 // VecSeq_CUPM - Private API
122 // ==========================================================================================
123 
124 template <device::cupm::DeviceType T>
125 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
126 {
127   return static_cast<Vec_Seq *>(v->data);
128 }
129 
// name of the device VecType this instantiation implements (VECSEQCUPM() presumably
// expands to the CUDA- or HIP-specific type name depending on T -- provided elsewhere)
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
135 
136 template <device::cupm::DeviceType T>
137 inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
138 {
139   return VecDestroy_Seq(v);
140 }
141 
142 template <device::cupm::DeviceType T>
143 inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
144 {
145   return VecResetArray_Seq(v);
146 }
147 
148 template <device::cupm::DeviceType T>
149 inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
150 {
151   return VecPlaceArray_Seq(v, a);
152 }
153 
// Build the host-side (Vec_Seq) representation of v, optionally adopting host_array.
// VecSeq requires a communicator of size 1; larger communicators are an error.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  // report that no allocation remains outstanding after this call
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
166 
// for functions with an early return based on vec size we still need to artificially bump the
168 // object state. This is to prevent the following:
169 //
170 // 0. Suppose you have a Vec {
171 //   rank 0: [0],
172 //   rank 1: [<empty>]
173 // }
174 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
175 // 2. Vec enters e.g. VecSet(10)
176 // 3. rank 1 has local size 0 and bails immediately
177 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
178 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
179 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
180 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
181 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
182 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  // bump the object state on ranks whose local slice is empty but whose global vector is
  // not, so every rank's state stays in lock-step (see the deadlock scenario above)
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
190 
// common constructor core: build the host (Vec_Seq) side, then attach the CUPM device
// side. Both array arguments are optional; when given they are passed through (not
// copied) to the respective initializers.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  // NOTE(review): second argument looks like an "allocate now" flag -- confirm against
  // Initialize_CUPMBase()
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
199 
200 template <device::cupm::DeviceType T>
201 template <typename BinaryFuncT>
202 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
203 {
204   PetscFunctionBegin;
205   if (const auto n = zout->map->n) {
206     PetscDeviceContext dctx;
207     cupmStream_t       stream;
208 
209     PetscCall(GetHandles_(&dctx, &stream));
210     // clang-format off
211     PetscCallThrust(
212       const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());
213 
214       THRUST_CALL(
215         thrust::transform,
216         stream,
217         dxptr, dxptr + n,
218         thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
219         thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
220         std::forward<BinaryFuncT>(binary)
221       )
222     );
223     // clang-format on
224     PetscCall(PetscLogFlops(n));
225     PetscCall(PetscDeviceContextSynchronize(dctx));
226   } else {
227     PetscCall(MaybeIncrementEmptyLocalVec(zout));
228   }
229   PetscFunctionReturn(PETSC_SUCCESS);
230 }
231 
232 template <device::cupm::DeviceType T>
233 template <typename UnaryFuncT>
234 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
235 {
236   const auto inplace = !yin || (xinout == yin);
237 
238   PetscFunctionBegin;
239   if (const auto n = xinout->map->n) {
240     PetscDeviceContext dctx;
241     cupmStream_t       stream;
242     const auto         apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
243       PetscFunctionBegin;
244       // clang-format off
245       PetscCallThrust(
246         const auto xptr = thrust::device_pointer_cast(xinout);
247 
248         THRUST_CALL(
249           thrust::transform,
250           stream,
251           xptr, xptr + n,
252           (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
253           std::forward<UnaryFuncT>(unary)
254         )
255       );
256       PetscFunctionReturn(PETSC_SUCCESS);
257     };
258 
259     PetscCall(GetHandles_(&dctx, &stream));
260     if (inplace) {
261       PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
262     } else {
263       PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
264     }
265     PetscCall(PetscLogFlops(n));
266     PetscCall(PetscDeviceContextSynchronize(dctx));
267   } else {
268     if (inplace) {
269       PetscCall(MaybeIncrementEmptyLocalVec(xinout));
270     } else {
271       PetscCall(MaybeIncrementEmptyLocalVec(yin));
272     }
273   }
274   PetscFunctionReturn(PETSC_SUCCESS);
275 }
276 
277 // ==========================================================================================
278 // VecSeq_CUPM - Public API - Constructors
279 // ==========================================================================================
280 
// VecCreateSeqCUPM()
// Create a sequential CUPM vector of local (== global) size n and block size bs.
// call_set_type controls whether the type is set (and hence storage allocated) here.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  // sequential: local size and global size coincide
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
289 
// VecCreateSeqCUPMWithArrays()
// Create a sequential CUPM vector that adopts caller-provided host and/or device arrays
// (either may be NULL); the arrays are used directly, not copied.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
304 
// v->ops->duplicate
// Create a new vector y with the same layout and type as v (values are not copied).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
316 
317 // ==========================================================================================
318 // VecSeq_CUPM - Public API - Utility
319 // ==========================================================================================
320 
// v->ops->bindtocpu
// Toggle between host (VecXXX_Seq) and device implementations in the op table.
// VecSetOp_CUPM presumably installs the first (host) entry when usehost is true and the
// second (device) entry otherwise -- confirm against the macro definition.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot always uses the host implementation, regardless of the bind flag
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}
346 
347 // ==========================================================================================
348 // VecSeq_CUPM - Public API - Mutators
349 // ==========================================================================================
350 
// v->ops->getlocalvector or v->ops->getlocalvectorread
// Alias the local representation of v into w without copying. If w is itself a seq CUPM
// vector its existing host and device allocations are released first; then either v's
// internals are stolen wholesale (native v) or w borrows v's host array via
// VecGetArray[Read]().
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // w's own storage is about to be replaced, discard it now
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // free with the allocator matching how the array was obtained: the RAII guard
        // switches to the pinned-host allocator iff w's memory was pinned (and clears the
        // pinned flag in the same motion)
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      // release w's device array (if any) and the device struct itself
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // steal v's innards outright; RestoreLocalVector() hands them back
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // cannot steal: borrow only the host array from v
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
406 
// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Undo GetLocalVector(): either hand v's stolen internals back (native path) or release
// the borrowed host array via VecRestoreArray[Read]() and free any device scratch that
// was allocated for w in the meantime.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // drop device memory w acquired while it aliased v (e.g. via DeviceAllocateCheck_())
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
442 
443 // ==========================================================================================
444 // VecSeq_CUPM - Public API - Compute Methods
445 // ==========================================================================================
446 
// v->ops->aypx, y = x + alpha*y
// alpha == 0 degenerates to y = x (device memcpy); alpha == 1 to a single axpy; the
// general case is a scal followed by an axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    // y = x
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // y += x
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        // y = alpha*y, then y += x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  // only synchronize if there was any work to wait on
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
490 
// v->ops->axpy, y += alpha*x
// Falls back to the host implementation when x is not itself a CUPM vector (its data
// would not be resident on the device).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
515 
516 // v->ops->pointwisedivide
517 template <device::cupm::DeviceType T>
518 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
519 {
520   PetscFunctionBegin;
521   if (xin->boundtocpu || yin->boundtocpu) {
522     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
523   } else {
524     // note order of arguments! xin and yin are read, win is written!
525     PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
526   }
527   PetscFunctionReturn(PETSC_SUCCESS);
528 }
529 
530 // v->ops->pointwisemult
531 template <device::cupm::DeviceType T>
532 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
533 {
534   PetscFunctionBegin;
535   if (xin->boundtocpu || yin->boundtocpu) {
536     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
537   } else {
538     // note order of arguments! xin and yin are read, win is written!
539     PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
540   }
541   PetscFunctionReturn(PETSC_SUCCESS);
542 }
543 
544 namespace detail
545 {
546 
547 struct reciprocal {
548   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
549   {
550     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
551     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
552     // everything in PetscScalar...
553     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
554   }
555 };
556 
557 } // namespace detail
558 
// v->ops->reciprocal, x[i] = 1/x[i] in place (zero entries stay zero, see
// detail::reciprocal)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
567 
// v->ops->waxpy, w = alpha*x + y
// alpha == 0 degenerates to w = y; otherwise copy y into w and axpy alpha*x on top.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      // w = y, then w += alpha*x
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
594 
595 namespace kernels
596 {
597 
// x[i] += sum_j a[j]*y_j[i] for a compile-time pack of y pointers. The coefficients are
// staged through shared memory once per block so the per-element loop need not re-read
// them from global memory; elements are visited via a grid-stride loop.
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory (first N threads of the block, then everyone waits)
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results! (different accumulation order
    // changes floating-point rounding)
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}
629 
630 } // namespace kernels
631 
632 namespace detail
633 {
634 
// a helper-struct whose unnamed std::size_t parameter exists only to be consumed by pack
// expansion, such that
// typename repeat_type<MyType, IdxParamPack>::type...
// expands to
// MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};
644 
645 } // namespace detail
646 
// Launch MAXPY_kernel with exactly sizeof...(Idx) y-vectors: the index pack is expanded
// both into the kernel's template arguments (one const PetscScalar* per index, via
// detail::repeat_type) and into the per-vector device-array arguments.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
663 
// Dispatch one kernel launch covering the next N vectors starting at yin[yidx] (with the
// matching coefficient slice aptr + yidx), then advance yidx past them.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
673 
// v->ops->maxpy, x += sum_j alpha[j]*y_j
// The host coefficient array is copied to the device once, then the y-vectors are
// consumed in kernel launches of up to 8 vectors at a time (8-vector launches until fewer
// than 8 remain, then one launch for the remainder).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
730 
// v->ops->dot, z = conj(x) . y computed via cupmBlasXdot; an empty vector yields 0
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    // n multiplies + (n - 1) additions
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
751 
752   #define MDOT_WORKGROUP_NUM  128
753   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
754 
755 namespace kernels
756 {
757 
758 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
759 {
760   const auto group_entries = (size - 1) / gridDim.x + 1;
761   // for very small vectors, a group should still do some work
762   return group_entries ? group_entries : 1;
763 }
764 
// Partial dot products of x against a compile-time pack of N y-vectors. Each block
// accumulates thread-local sums over its slice, reduces them in shared memory, and
// writes one partial result per y-vector to results[bx + i*gridDim.x]; a subsequent
// reduction over blocks (not shown here) produces the final values.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  // strided accumulation over this block's slice [bx*worksize, end)
  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stage the per-thread sums in shared memory, one MDOT_WORKGROUP_SIZE row per y-vector
  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel (tree) reduction within the block, halving the stride each round
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
813 
814 namespace
815 {
816 
// Collapse the per-workgroup partial sums produced by MDot_kernel: on entry results holds
// size * MDOT_WORKGROUP_SIZE partial values (one contiguous run of MDOT_WORKGROUP_SIZE
// entries per dot product); on exit results[0..size) holds the final sums.
// NOTE(review): each thread buffers its sums in local_results[8], so this assumes the
// grid-stride loop assigns at most 8 of the `size` entries to any single thread -- confirm
// against the launch configuration used in MDot_ (PetscCUPMLaunchKernel1D over nvbatch).
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  // NOTE(review): this second pass presumably revisits the exact same i values as the
  // first (identical grid-stride pattern), popping the buffered sums in reverse order --
  // verify grid_stride_1D iterates deterministically for a given thread
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}
853 
854 } // namespace
855 
856 } // namespace kernels
857 
// Launch a single MDot_kernel instantiation computing sizeof...(Idx) dot products of xarr
// against yin[0..sizeof...(Idx)), writing MDOT_WORKGROUP_NUM partial sums per product into
// results. The index sequence only serves to expand the y device pointers into the
// variadic kernel arguments.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
876 
877 template <device::cupm::DeviceType T>
878 template <int N>
879 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
880 {
881   PetscFunctionBegin;
882   PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
883   yidx += N;
884   PetscFunctionReturn(PETSC_SUCCESS);
885 }
886 
// Real-scalar MDot: dot products are computed in batches of up to 8 by the custom
// kernels::MDot_kernel, each batch writing MDOT_WORKGROUP_NUM partial sums per vector into
// the d_results scratch array; kernels::sum_kernel then collapses the partials and the
// final values are copied back to z asynchronously. Batches may be round-robined over
// forked sub-streams to expose concurrency.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      if (num_sub_streams) {
        // round-robin the batches over the forked sub-contexts
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a lone final vector is excluded from nvbatch above; hand it to cuBLAS directly,
        // which writes its result straight into the host array z
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // collapse the MDOT_WORKGROUP_NUM partial sums of each batched dot product
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}
976 
977   #undef MDOT_WORKGROUP_NUM
978   #undef MDOT_WORKGROUP_SIZE
979 
// Complex-scalar MDot: one cuBLAS dot per y vector, round-robined over up to 8 forked
// sub-contexts. Results are written device-side (pointer mode DEVICE) into the scratch
// buffer d_z -- avoiding a host synchronization per vector -- then copied back to z in a
// single async transfer after the sub-contexts are joined.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // temporarily switch the handle to device pointer mode so the dot result can land in
    // d_z without synchronizing; restore the previous mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}
1014 
// v->ops->mdot
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  // Computes the nv dot products z[i] = (xin, yin[i]). Tag-dispatches to one of the two
  // MDot_ overloads: true_type (complex build) -> cuBLAS loop, false_type (real build) ->
  // custom batched kernel.
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: every dot product is zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1037 
1038 // v->ops->set
1039 template <device::cupm::DeviceType T>
1040 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1041 {
1042   const auto         n = xin->map->n;
1043   PetscDeviceContext dctx;
1044   cupmStream_t       stream;
1045 
1046   PetscFunctionBegin;
1047   PetscCall(GetHandles_(&dctx, &stream));
1048   {
1049     const auto xptr = DeviceArrayWrite(dctx, xin);
1050 
1051     if (alpha == PetscScalar(0.0)) {
1052       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1053     } else {
1054       const auto dptr = thrust::device_pointer_cast(xptr.data());
1055 
1056       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1057     }
1058     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1059   }
1060   PetscFunctionReturn(PETSC_SUCCESS);
1061 }
1062 
// v->ops->scale
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
{
  // xin *= alpha via cuBLAS xscal, with fast paths for alpha == 1 (a complete no-op) and
  // alpha == 0 (delegates to Set(), which memsets rather than multiplies)
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(Set(xin, alpha));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // zero-length local vector -- presumably still records the "modification" on the
    // object; see MaybeIncrementEmptyLocalVec (declared elsewhere)
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1086 
1087 // v->ops->tdot
1088 template <device::cupm::DeviceType T>
1089 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1090 {
1091   PetscFunctionBegin;
1092   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1093     PetscDeviceContext dctx;
1094     cupmBlasHandle_t   cupmBlasHandle;
1095 
1096     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1097     PetscCall(PetscLogGpuTimeBegin());
1098     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1099     PetscCall(PetscLogGpuTimeEnd());
1100     PetscCall(PetscLogGpuFlops(2 * n - 1));
1101   } else {
1102     *z = 0.0;
1103   }
1104   PetscFunctionReturn(PETSC_SUCCESS);
1105 }
1106 
// v->ops->copy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  // yout <- xin. The memcpy direction is derived from the two vectors' offload masks so
  // that xin is read from wherever its data is currently valid, and yout is written to
  // wherever its storage (if any) already lives.
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      // NOTE(review): xptr is a raw pointer extracted from a temporary array handle that
      // is destroyed at the end of this statement; presumably the handle does not own the
      // storage -- confirm
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1176 
// v->ops->swap
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
{
  // Exchange the device contents of xin and yin via cuBLAS xswap; a no-op when both
  // arguments are the same object.
  PetscFunctionBegin;
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // zero-length local vectors: both objects are still marked as touched
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1198 
// v->ops->axpby
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  // y = alpha*x + beta*y. Trivial coefficient combinations are routed to cheaper kernels
  // (Scale/AXPY/AYPX); otherwise beta == 0 becomes copy-then-scale, and the general case
  // scale-then-axpy. Note PetscLogGpuTimeBegin() is issued inside the branches but
  // matched by a single PetscLogGpuTimeEnd() after the scope.
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    PetscCall(AXPY(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    PetscCall(AYPX(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1245 
1246 // v->ops->axpbypcz
1247 template <device::cupm::DeviceType T>
1248 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1249 {
1250   PetscFunctionBegin;
1251   if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma));
1252   PetscCall(AXPY(zin, alpha, xin));
1253   PetscCall(AXPY(zin, beta, yin));
1254   PetscFunctionReturn(PETSC_SUCCESS);
1255 }
1256 
// v->ops->norm
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  // Computes the requested norm(s) via cuBLAS (asum / nrm2 / amax). For NORM_1_AND_2 the
  // 1-norm goes to z[0] and the 2-norm to z[1] (note the ++z before the fallthrough).
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // BLAS amax returns a 1-based (FORTRAN-style) index, hence the - 1 below
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // NOTE(review): xv is consumed immediately after the async D2H copy; this relies on
      // the copy into pageable stack memory behaving synchronously -- confirm
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty local vector: norms are zero; the second assignment also clears z[1] when
    // both norms were requested (the index expression is 0 or 1)
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1306 
1307 namespace detail
1308 {
1309 
// Elementwise kernel for DotNorm2: given entries s and t, produces the pair
// (s * conj(t), t * conj(t)) -- the per-entry contributions to the dot product and to
// |t|^2 respectively
struct dotnorm2_mult {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t);

    return {s * conjt, t * conjt};
  }
};
1318 
1319 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1320 // would do it myself but now I am worried that they do so on purpose...
// Componentwise sum of two (dot, norm) partial-result tuples; the reduction operator for
// the fused inner_product in DotNorm2
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
};
1326 
1327 } // namespace detail
1328 
// v->ops->dotnorm2
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  // Computes *dp = sum_i s_i * conj(t_i) and *nm = sum_i t_i * conj(t_i) in a single
  // fused pass over both vectors via thrust::inner_product (see dotnorm2_mult and
  // dotnorm2_tuple_plus above).
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // the host-side tie assignment implies the reduction result is available on return;
    // presumably THRUST_CALL synchronizes the stream -- no explicit sync is issued here
    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1356 
1357 namespace detail
1358 {
1359 
// Elementwise complex conjugation functor for PointwiseUnary_ (identity for real scalars)
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};
1363 
1364 } // namespace detail
1365 
// v->ops->conjugate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
{
  // In-place elementwise conjugation; a no-op in real builds where
  // PetscDefined(USE_COMPLEX) expands to 0 and the branch is dead
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1374 
1375 namespace detail
1376 {
1377 
// Extract the real part of a scalar, either alone or as the first member of a
// (value, index) tuple; used by MinMax_ to feed (possibly complex) vector entries to the
// real-valued reductions
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
};
1383 
1384 // deriving from Operator allows us to "store" an instance of the operator in the class but
1385 // also take advantage of empty base class optimization if the operator is stateless
1386 template <typename Operator>
1387 class tuple_compare : Operator {
1388 public:
1389   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1390   using operator_type = Operator;
1391 
1392   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1393   {
1394     if (op_()(y.get<0>(), x.get<0>())) {
1395       // if y is strictly greater/less than x, return y
1396       return y;
1397     } else if (y.get<0>() == x.get<0>()) {
1398       // if equal, prefer lower index
1399       return y.get<1>() < x.get<1>() ? y : x;
1400     }
1401     // otherwise return x
1402     return x;
1403   }
1404 
1405 private:
1406   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1407 };
1408 
1409 } // namespace detail
1410 
// Common driver for Min()/Max(): a thrust reduction over the vector, either over
// (value, index) zip tuples when the caller wants the location p, or over plain values
// otherwise. *m must be pre-seeded by the caller with the reduction identity
// (PETSC_MIN_REAL for max, PETSC_MAX_REAL for min); *p (if requested) is left at -1 for
// empty vectors.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair every entry with its index so the reduction can report the location
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value-only reduction, no index tracking needed
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}
1467 
1468 // v->ops->max
1469 template <device::cupm::DeviceType T>
1470 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
1471 {
1472   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1473   using unary_functor = thrust::maximum<PetscReal>;
1474 
1475   PetscFunctionBegin;
1476   *m = PETSC_MIN_REAL;
1477   // use {} constructor syntax otherwise most vexing parse
1478   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1479   PetscFunctionReturn(PETSC_SUCCESS);
1480 }
1481 
1482 // v->ops->min
1483 template <device::cupm::DeviceType T>
1484 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
1485 {
1486   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1487   using unary_functor = thrust::minimum<PetscReal>;
1488 
1489   PetscFunctionBegin;
1490   *m = PETSC_MAX_REAL;
1491   // use {} constructor syntax otherwise most vexing parse
1492   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1493   PetscFunctionReturn(PETSC_SUCCESS);
1494 }
1495 
1496 // v->ops->sum
1497 template <device::cupm::DeviceType T>
1498 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
1499 {
1500   PetscFunctionBegin;
1501   if (const auto n = v->map->n) {
1502     PetscDeviceContext dctx;
1503     cupmStream_t       stream;
1504 
1505     PetscCall(GetHandles_(&dctx, &stream));
1506     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1507     // REVIEW ME: why not cupmBlasXasum()?
1508     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1509     // REVIEW ME: must be at least n additions
1510     PetscCall(PetscLogGpuFlops(n));
1511   } else {
1512     *sum = 0.0;
1513   }
1514   PetscFunctionReturn(PETSC_SUCCESS);
1515 }
1516 
// Add the scalar `shift` to every entry of v via an on-device pointwise transform.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  // make_plus_equals(shift) builds the elementwise x += shift functor
  PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1524 
// Fill v with random values drawn from rand. If rand is a CURAND-backed generator the
// values are written directly into the device array, otherwise they are generated into
// the host array.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    // CURAND can generate on-device; any other generator produces host values
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    // keep the object state bookkeeping consistent even when there is nothing to fill
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(PETSC_SUCCESS);
}
1544 
// v->ops->setpreallocationcoo
// Set up COO (coordinate-format) assembly for v: the host-side setup is delegated to the
// standard Seq implementation, then the base class mirrors the resulting maps onto the
// device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // build the host-side COO maps first
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  // then replicate them in device memory for SetValuesCOO()
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1557 
1558 namespace kernels
1559 {
1560 
// Shared device-side implementation for COO value insertion. For each of the n output
// entries i it sums the raw input values vv[perm[k]] for k in [jmap[i], jmap[i + 1]) and
// either overwrites (INSERT_VALUES) or accumulates into the destination xv[xvindex(i)].
// xvindex maps the compacted COO entry index i to the actual index into xv.
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  // grid-stride loop over the n destination entries
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    // jmap[i]..jmap[i+1] delimits the (possibly repeated) raw COO entries that map to idx
    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}
1579 
namespace
{

// Kernel entry point for vector COO assembly: uses the identity index mapping, i.e. the
// i-th compacted COO entry lands at xv[i].
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace
1590 
  #if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear to see this. Likely because the function using it is in a template.
//
// Taking the kernels' addresses here odr-uses them, which suppresses the warning.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
  #endif
1615 
1616 } // namespace kernels
1617 
// v->ops->setvaluescoo
// Scatter/accumulate the user-supplied COO values v[] into x using the device-side maps
// (jmap1_d/perm1_d) built by SetPreallocationCOO(). v[] may live in host or device
// memory; host data is first staged into a temporary device buffer.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v); // device-visible view of v (possibly a staged copy)
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  // determine where the caller's buffer lives
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // v[] is on the host: stage it asynchronously into a temporary device buffer
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry, so write-only access suffices; ADD needs read-write
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // release the staging buffer (if any); the free is enqueued on the same context
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1650 
1651 } // namespace impl
1652 
1653 // ==========================================================================================
1654 // VecSeq_CUPM - Implementations
1655 // ==========================================================================================
1656 
1657 namespace
1658 {
1659 
1660 template <device::cupm::DeviceType T>
1661 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
1662 {
1663   PetscFunctionBegin;
1664   PetscValidPointer(v, 4);
1665   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
1666   PetscFunctionReturn(PETSC_SUCCESS);
1667 }
1668 
// Create a sequential CUPM vector of block size bs and local length n that uses the
// caller-provided host (cpuarray) and/or device (gpuarray) buffers instead of allocating
// its own storage.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  // cpuarray may legitimately be NULL (device-only vector); only validate when given
  if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4);
  PetscValidPointer(v, 6);
  PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1678 
// Common implementation behind the typed get-array wrappers below: validates the
// arguments, substitutes a default device context when dctx is NULL, then fetches the
// device array with the requested memory-access mode.
template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
  PetscValidPointer(a, 2);
  // a NULL dctx is replaced by the default context
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1689 
// Counterpart to VecCUPMGetArrayAsync_Private(): returns the device array obtained with
// the same memory-access mode, substituting a default device context when dctx is NULL.
template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1699 
// Get read-write access to v's device array (dctx defaults to the null/default context).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1707 
// Restore the device array obtained with VecCUPMGetArrayAsync() (read-write mode).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1715 
// Get read-only access to v's device array.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  // const_cast is safe: the shared private helper traffics in non-const pointers, but
  // READ access mode means the array will not be written through *a
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1723 
// Restore the device array obtained with VecCUPMGetArrayReadAsync() (read-only mode).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  // const_cast mirrors the one in the matching Get; READ mode guarantees no writes
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1731 
// Get write-only access to v's device array (previous contents need not be valid).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1739 
// Restore the device array obtained with VecCUPMGetArrayWriteAsync() (write-only mode).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1747 
// Forward to the impl-class PlaceArray for device memory: vin temporarily uses the
// caller-provided buffer a as its device array (undone by VecCUPMResetArrayAsync()).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1756 
// Forward to the impl-class ReplaceArray for device memory: vin permanently adopts the
// caller-provided buffer a as its device array.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1765 
// Forward to the impl-class ResetArray for device memory: undo a prior
// VecCUPMPlaceArrayAsync(), restoring vin's original device array.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1774 
1775 } // anonymous namespace
1776 
1777 } // namespace cupm
1778 
1779 } // namespace vec
1780 
1781 } // namespace Petsc
1782 
1783 #endif // __cplusplus
1784 
1785 #endif // PETSCVECSEQCUPM_HPP
1786