xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 4e04ae5a56968941e5bdff79cfe464b412fb7b34)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 
6 #if defined(__cplusplus)
7   #include <petsc/private/randomimpl.h> // for _p_PetscRandom
8 
9   #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence
10 
11   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
12   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
13 
14   #if PetscDefined(USE_COMPLEX)
15     #include <thrust/transform_reduce.h>
16   #endif
17   #include <thrust/transform.h>
18   #include <thrust/reduce.h>
19   #include <thrust/functional.h>
20   #include <thrust/tuple.h>
21   #include <thrust/device_ptr.h>
22   #include <thrust/iterator/zip_iterator.h>
23   #include <thrust/iterator/counting_iterator.h>
24   #include <thrust/iterator/constant_iterator.h>
25   #include <thrust/inner_product.h>
26 
27 namespace Petsc
28 {
29 
30 namespace vec
31 {
32 
33 namespace cupm
34 {
35 
36 namespace impl
37 {
38 
39 // ==========================================================================================
40 // VecSeq_CUPM
41 // ==========================================================================================
42 
// Device-backend-templated implementation of the sequential (single-rank) CUPM vector.
// T selects the concrete backend (CUDA or HIP); the CRTP base provides the shared
// host/device array management machinery.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // hooks required by the CRTP base: casts and type-name accessors
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  // hooks forwarding to the host (Vec_Seq) implementation
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state of a locally-empty (but globally non-empty) vector; see the
  // comment above its definition for why this is needed
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};
122 
123 // ==========================================================================================
124 // VecSeq_CUPM - Private API
125 // ==========================================================================================
126 
127 template <device::cupm::DeviceType T>
128 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
129 {
130   return static_cast<Vec_Seq *>(v->data);
131 }
132 
// type-name hook for the CRTP base: the device-capable sequential vector type
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
138 
// type-name hook for the CRTP base: the plain host sequential vector type
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}
144 
// forward destruction of the host-side (Vec_Seq) portion to the existing Seq routine
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}
150 
// forward the host-array reset to the existing Seq routine
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}
156 
// forward host-array placement to the existing Seq routine
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}
162 
163 template <device::cupm::DeviceType T>
164 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
165 {
166   PetscMPIInt size;
167 
168   PetscFunctionBegin;
169   if (alloc_missing) *alloc_missing = PETSC_FALSE;
170   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
171   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
172   PetscCall(VecCreate_Seq_Private(v, host_array));
173   PetscFunctionReturn(PETSC_SUCCESS);
174 }
175 
// for functions with an early return based on vec size we still need to artificially bump the
177 // object state. This is to prevent the following:
178 //
179 // 0. Suppose you have a Vec {
180 //   rank 0: [0],
181 //   rank 1: [<empty>]
182 // }
183 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
184 // 2. Vec enters e.g. VecSet(10)
185 // 3. rank 1 has local size 0 and bails immediately
186 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
187 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
188 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
189 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
190 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
191 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
192 template <device::cupm::DeviceType T>
193 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
194 {
195   PetscFunctionBegin;
196   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
197   PetscFunctionReturn(PETSC_SUCCESS);
198 }
199 
// Common creation core: build the host-side Seq structure, then initialize the CUPM base
// with the (possibly null) caller-provided host and device arrays.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  // PETSC_FALSE: do not allocate eagerly, arrays may have been supplied above
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
208 
// Apply a binary functor elementwise on device: zout[i] = binary(xin[i], yin[i]).
// If the local vector is empty, only (maybe) bumps zout's object state so that other
// ranks observe a consistent state change (see MaybeIncrementEmptyLocalVec()).
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
240 
// Apply a unary functor elementwise on device. If yin is null (or equal to xinout) the
// transform is done in place on xinout, otherwise xinout is read and yin is written.
// Empty local vectors only (maybe) get their object state bumped.
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;
    // shared driver for both the in-place and out-of-place cases; the output iterator
    // falls back to the input pointer when no distinct output was given
    const auto         apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
286 
287 // ==========================================================================================
288 // VecSeq_CUPM - Public API - Constructors
289 // ==========================================================================================
290 
// VecCreateSeqCUPM()
// Create a sequential CUPM vector of local (== global) size n with block size bs;
// call_set_type controls whether VecSetType() is invoked by the base creation routine.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  // sequential: local and global sizes coincide, hence n is passed twice
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
299 
// VecCreateSeqCUPMWithArrays()
// Create a sequential CUPM vector adopting caller-provided host and/or device arrays
// (either may be null).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
314 
// v->ops->duplicate
// Create a new vector y with the same layout and type as v (values are not copied).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
326 
327 // ==========================================================================================
328 // VecSeq_CUPM - Public API - Utility
329 // ==========================================================================================
330 
// v->ops->bindtocpu
// Toggle between host (usehost == PETSC_TRUE) and device execution: the base routine
// handles data movement, then each op-table entry is pointed at either the Seq (host)
// or the CUPM (device) implementation via VecSetOp_CUPM.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device implementation, the Seq version is used in both modes
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  VecSetOp_CUPM(errorwnorm, nullptr, ErrorWnorm);
  PetscFunctionReturn(PETSC_SUCCESS);
}
357 
358 // ==========================================================================================
359 // VecSeq_CUPM - Public API - Mutators
360 // ==========================================================================================
361 
// v->ops->getlocalvector or v->ops->getlocalvectorread
// Make w an alias for the local portion of v. Fast path: when v is "native" and w is a
// seq CUPM vector, w's own storage is released and v's data/spptr pointers are adopted
// wholesale. Slow path: w borrows v's host array via VecGetArray[Read]().
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release everything w currently owns so it can adopt (or lazily allocate) new storage
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // temporarily switch the allocator to match how the array was allocated (pinned
        // vs plain host memory), clearing w->pinned_memory in the process
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // adopt v's storage directly; w now shares data with v until RestoreLocalVector()
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // borrow only the host array, respecting the requested access mode
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
417 
// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Undo GetLocalVector(): hand w's (adopted or borrowed) storage back to v and detach w.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    // return the borrowed host array with the same access mode it was taken with
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // free w's private device-side allocation made by DeviceAllocateCheck_()
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
453 
454 // ==========================================================================================
455 // VecSeq_CUPM - Public API - Compute Methods
456 // ==========================================================================================
457 
// v->ops->aypx
// y = x + alpha*y. alpha == 0 degenerates to a device-to-device copy; alpha == 1 is a
// single axpy; otherwise scal followed by axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    // y <- x
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // y <- y + x
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        // y <- alpha*y, then y <- y + x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
501 
// v->ops->axpy
// y = y + alpha*x, using cupmBLAS when x is itself a CUPM vector, and falling back to
// the host Seq routine otherwise.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
526 
527 // v->ops->pointwisedivide
528 template <device::cupm::DeviceType T>
529 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
530 {
531   PetscFunctionBegin;
532   if (xin->boundtocpu || yin->boundtocpu) {
533     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
534   } else {
535     // note order of arguments! xin and yin are read, win is written!
536     PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
537   }
538   PetscFunctionReturn(PETSC_SUCCESS);
539 }
540 
541 // v->ops->pointwisemult
542 template <device::cupm::DeviceType T>
543 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
544 {
545   PetscFunctionBegin;
546   if (xin->boundtocpu || yin->boundtocpu) {
547     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
548   } else {
549     // note order of arguments! xin and yin are read, win is written!
550     PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
551   }
552   PetscFunctionReturn(PETSC_SUCCESS);
553 }
554 
555 namespace detail
556 {
557 
558 struct reciprocal {
559   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
560   {
561     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
562     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
563     // everything in PetscScalar...
564     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
565   }
566 };
567 
568 } // namespace detail
569 
// v->ops->reciprocal
// x[i] = 1/x[i] in place (entries that are exactly zero are left as zero, see
// detail::reciprocal).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
578 
// v->ops->waxpy
// w = alpha*x + y. alpha == 0 reduces to w = y (plain copy); otherwise w is seeded with
// y via an async memcpy and then updated with a single axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      // w <- y, then w <- w + alpha*x
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
604 }
605 
606 namespace kernels
607 {
608 
// Device kernel for x += sum_j a[j]*y[j] over sizeof...(Args) vectors at once. The
// coefficients are first staged into shared memory (one per y vector), then each thread
// accumulates its elements via the grid_stride_1D helper.
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    // accumulate starting from the current x value: one global read, one global write
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}
640 
641 } // namespace kernels
642 
643 namespace detail
644 {
645 
// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
// typename repeat_type<MyType, IdxParamPack>...
// expands to
// MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
// (the second, unnamed parameter only exists to be expanded against the pack)
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};
655 
656 } // namespace detail
657 
// Launch MAXPY_kernel for exactly sizeof...(Idx) y vectors: the index sequence is
// expanded both to build the kernel's template argument list (one const PetscScalar *
// per vector, via repeat_type) and to unpack the device arrays of yin[Idx]...
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
674 
// Launch one MAXPY kernel covering the next N vectors starting at offset yidx, then
// advance yidx past them (used by MAXPY()'s chunked dispatch loop).
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
684 
// v->ops->maxpy
// x += sum_j alpha[j]*y[j] for nv vectors. The coefficients are copied to the device
// once, then the vectors are processed in chunks of up to 8 per kernel launch (the
// switch handles remainders of 1..7, the default case consumes 8 at a time).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
741 
// v->ops->dot
// z = conj(x)^T y (PETSc conjugates the second cupmBlasXdot argument by convention);
// an empty local vector yields z = 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
762 
763   #define MDOT_WORKGROUP_NUM  128
764   #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
765 
namespace kernels
{

// Number of x entries each workgroup (block) is responsible for; clamped below by 1 so
// that even very small vectors give every group something to do
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

// Computes N = sizeof...(y) partial dot products x . y[j] in one pass, reusing each loaded
// x[i] for every y vector. Each block tree-reduces its partial sums in shared memory, then
// thread tx < N writes the block-sum for vector tx to results[bx + tx * gridDim.x], i.e.
// results holds gridDim.x partial sums per vector which sum_kernel() below collapses.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  // one MDOT_WORKGROUP_SIZE-wide scratch region per y vector
  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stash each thread's N partial sums in shared memory for the tree reduction below
  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

namespace
{

// Collapses the MDOT_WORKGROUP_NUM per-block partial sums left by MDot_kernel() into one
// value per vector, in place: on exit results[i] is the full dot product of vector i for
// 0 <= i < size.
//
// NOTE(review): local_results holds at most 8 sums per thread, and the write-back loop
// below pops them in reverse (LIFO) order while grid_stride_1D visits i in the same
// increasing order both times. Both are presumably only correct because the launch in
// MDot_() (PetscCUPMLaunchKernel1D sized by nvbatch) gives each thread at most one i --
// confirm if the launch configuration ever changes.
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // namespace

} // namespace kernels
868 
// Launches MDot_kernel for sizeof...(Idx) y vectors at once: the pack expansion turns each
// index into a `const PetscScalar *` kernel argument pointing at the device array of the
// corresponding entry of yin
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
887 
// Dispatches a batch of N dot products starting at yin[yidx]; the per-workgroup partial
// sums land in results at an offset of MDOT_WORKGROUP_NUM entries per preceding y vector.
// yidx is advanced past the N vectors consumed by this batch.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
897 
// MDot implementation for real scalars (std::false_type <=> !PetscDefined(USE_COMPLEX)):
// processes the y vectors in batches of up to 8 via the hand-rolled kernels::MDot_kernel(),
// round-robining the batches over forked sub-streams, then collapses the per-workgroup
// partial sums in d_results with kernels::sum_kernel() and copies them back to the host
// array z
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // when sub-streams exist, rotate through them so batches can overlap
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        cupmBlasHandle_t cupmBlasHandle;

        // a single leftover vector is handled by plain BLAS dot; this can only be the last
        // iteration (nv - yidx == 1), so its result z[nv - 1] does not collide with the
        // d_results -> z memcpy of the nvbatch = nv - 1 batched values below
        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // collapse the per-workgroup partial sums; on exit d_results[0..nvbatch) hold the final
  // dot products
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}
987 
988   #undef MDOT_WORKGROUP_NUM
989   #undef MDOT_WORKGROUP_SIZE
990 
// MDot implementation for complex scalars (std::true_type <=> PetscDefined(USE_COMPLEX)):
// one cupmBlasXdot() per y vector, round-robined over up to 8 forked sub-contexts. The
// BLAS handle is temporarily flipped to device pointer mode so each result is written
// straight into the device scratch array d_z, which is then copied back to z in a single
// async memcpy (the caller, MDot(), synchronizes dctx before z is consumed).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // d_z + i is a device address, so the handle must be in device pointer mode; the
    // previous mode is restored afterwards so other users of the handle are unaffected
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc
    // the second (cf. Dot())
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}
1025 
1026 // v->ops->mdot
1027 template <device::cupm::DeviceType T>
1028 inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
1029 {
1030   PetscFunctionBegin;
1031   if (PetscUnlikely(nv == 1)) {
1032     // dot handles nv = 0 correctly
1033     PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
1034   } else if (const auto n = xin->map->n) {
1035     PetscDeviceContext dctx;
1036 
1037     PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
1038     PetscCall(GetHandles_(&dctx));
1039     PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
1040     // REVIEW ME: double count of flops??
1041     PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
1042     PetscCall(PetscDeviceContextSynchronize(dctx));
1043   } else {
1044     PetscCall(PetscArrayzero(z, nv));
1045   }
1046   PetscFunctionReturn(PETSC_SUCCESS);
1047 }
1048 
1049 // v->ops->set
1050 template <device::cupm::DeviceType T>
1051 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1052 {
1053   const auto         n = xin->map->n;
1054   PetscDeviceContext dctx;
1055   cupmStream_t       stream;
1056 
1057   PetscFunctionBegin;
1058   PetscCall(GetHandles_(&dctx, &stream));
1059   {
1060     const auto xptr = DeviceArrayWrite(dctx, xin);
1061 
1062     if (alpha == PetscScalar(0.0)) {
1063       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1064     } else {
1065       const auto dptr = thrust::device_pointer_cast(xptr.data());
1066 
1067       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1068     }
1069     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1070   }
1071   PetscFunctionReturn(PETSC_SUCCESS);
1072 }
1073 
1074 // v->ops->scale
1075 template <device::cupm::DeviceType T>
1076 inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1077 {
1078   PetscFunctionBegin;
1079   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1080   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1081     PetscCall(Set(xin, alpha));
1082   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1083     PetscDeviceContext dctx;
1084     cupmBlasHandle_t   cupmBlasHandle;
1085 
1086     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1087     PetscCall(PetscLogGpuTimeBegin());
1088     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1089     PetscCall(PetscLogGpuTimeEnd());
1090     PetscCall(PetscLogGpuFlops(n));
1091     PetscCall(PetscDeviceContextSynchronize(dctx));
1092   } else {
1093     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1094   }
1095   PetscFunctionReturn(PETSC_SUCCESS);
1096 }
1097 
1098 // v->ops->tdot
1099 template <device::cupm::DeviceType T>
1100 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1101 {
1102   PetscFunctionBegin;
1103   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1104     PetscDeviceContext dctx;
1105     cupmBlasHandle_t   cupmBlasHandle;
1106 
1107     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1108     PetscCall(PetscLogGpuTimeBegin());
1109     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1110     PetscCall(PetscLogGpuTimeEnd());
1111     PetscCall(PetscLogGpuFlops(2 * n - 1));
1112   } else {
1113     *z = 0.0;
1114   }
1115   PetscFunctionReturn(PETSC_SUCCESS);
1116 }
1117 
// v->ops->copy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      // an unallocated cupm destination may be materialized on either side; match whichever
      // side x currently lives on to avoid an extra transfer
      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      // destination is on the device: acquire it write-only (no need to migrate its old
      // contents) and memcpy directly into it
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // destination lives on the host, so go through the generic host array accessors
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      // only device-to-host transfers count as GPU time
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1187 
1188 // v->ops->swap
1189 template <device::cupm::DeviceType T>
1190 inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1191 {
1192   PetscFunctionBegin;
1193   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1194   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1195     PetscDeviceContext dctx;
1196     cupmBlasHandle_t   cupmBlasHandle;
1197 
1198     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1199     PetscCall(PetscLogGpuTimeBegin());
1200     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1201     PetscCall(PetscLogGpuTimeEnd());
1202     PetscCall(PetscDeviceContextSynchronize(dctx));
1203   } else {
1204     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1205     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1206   }
1207   PetscFunctionReturn(PETSC_SUCCESS);
1208 }
1209 
// v->ops->axpby; computes y = alpha*x + beta*y, peeling off the special values of alpha
// and beta to cheaper operations first
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // y = beta*y
    PetscCall(Scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    // y = alpha*x + y
    PetscCall(AXPY(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    // y = beta*y + x
    PetscCall(AYPX(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // y = alpha*x: copy x into y, then scale y in place
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        // general case: y = beta*y, then y += alpha*x
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    // both branches above called PetscLogGpuTimeBegin(), so the matching End() is hoisted
    // out here (after the array-access scope closes)
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1256 
// v->ops->axpbypcz; computes z = alpha*x + beta*y + gamma*z, composed from Scale() and
// AXPY(). Scale() is skipped when gamma == 1 (it would be a no-op) and itself handles
// gamma == 0 via Set().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma));
  PetscCall(AXPY(zin, alpha, xin));
  PetscCall(AXPY(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1267 
// v->ops->norm; z receives one value, except for NORM_1_AND_2 where z[0] is the 1-norm
// and z[1] the 2-norm
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      // 1-norm via BLAS asum (sum of absolute values)
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through: NORM_1_AND_2 stores the 2-norm in the next slot
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // amax yields a 1-based index of the entry of largest magnitude, hence the -1 below
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // NOTE(review): xv is read immediately after this nominally-async copy with no
      // explicit stream sync; presumably the device-to-pageable-host copy (or the handle's
      // blocking mode) synchronizes -- confirm before changing the memcpy or stream setup
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: all requested norms are zero (index 1 is only written for NORM_1_AND_2)
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1317 
1318 namespace detail
1319 {
1320 
// Common base for the weighted-error-norm functors below: stores the ignore_max_ cutoff
// and converts an (error, tolerance) pair into a per-entry norm contribution
template <NormType wnormtype>
class ErrorWNormTransformBase {
public:
  // packed partial result: (norm, norma, normr, norm_loc, norma_loc, normr_loc)
  using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

  constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

protected:
  // a single entry's contribution: its scaled error and whether it was counted (loc is 0
  // or 1 here; the counts are summed during the reduction)
  struct NormTuple {
    PetscReal norm;
    PetscInt  loc;
  };

  // scaled error err/tol -- squared unless accumulating an infinity norm. Entries with a
  // nonpositive tolerance contribute nothing (norm 0, not counted).
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
  {
    if (tol > 0.) {
      const auto val = err / tol;

      return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
    } else {
      return {0.0, 0};
    }
  }

  // entries whose magnitude falls below this threshold are skipped by the derived functors
  PetscReal ignore_max_;
};
1347 
// Per-entry transform when no explicit error vector is supplied: the error is computed as
// |u - y| from the zipped (u, y, atol, rtol) input tuple
template <NormType wnormtype>
struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  // (u, y, absolute tolerance, relative tolerance) for one vector entry
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto u     = x.get<0>();
    const auto y     = x.get<1>();
    const auto au    = PetscAbsScalar(u);
    const auto ay    = PetscAbsScalar(y);
    // entries where either solution component is below the cutoff are ignored entirely
    // (zero tolerances make compute_norm_() skip them)
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(x.get<2>());
    const auto tolr  = skip ? 0.0 : PetscRealPart(x.get<3>()) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(u - y);
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};
1374 
// Per-entry transform when an explicit error vector is supplied: the error is |e| taken
// directly from the zipped (u, y, e, atol, rtol) input tuple
template <NormType wnormtype>
struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  // (u, y, error, absolute tolerance, relative tolerance) for one vector entry
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto au    = PetscAbsScalar(x.get<0>());
    const auto ay    = PetscAbsScalar(x.get<1>());
    // entries where either solution component is below the cutoff are ignored entirely
    // (zero tolerances make compute_norm_() skip them)
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(x.get<3>());
    const auto tolr  = skip ? 0.0 : PetscRealPart(x.get<4>()) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(x.get<2>());
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};
1399 
1400 template <NormType wnormtype>
1401 struct ErrorWNormReduce {
1402   using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;
1403 
1404   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
1405   {
1406     // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
1407     // result_type is a template, so in order to fix this we would need to write:
1408     //
1409     // lhs.template get<0>()
1410     //
1411     // which is unseemly.
1412     if (wnormtype == NORM_INFINITY) {
1413       // clang-format off
1414       return {
1415         PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
1416         PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
1417         PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
1418         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1419         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1420         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1421       };
1422       // clang-format on
1423     } else {
1424       // clang-format off
1425       return {
1426         thrust::get<0>(lhs) + thrust::get<0>(rhs),
1427         thrust::get<1>(lhs) + thrust::get<1>(rhs),
1428         thrust::get<2>(lhs) + thrust::get<2>(rhs),
1429         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1430         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1431         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1432       };
1433       // clang-format on
1434     }
1435   }
1436 };
1437 
// Runs a single fused thrust::transform_reduce over the zipped input ranges, producing the
// weighted norms and counted-entry totals in one pass. For the NORM_2 family the norms are
// accumulated as sums of squares and square-rooted at the end.
template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
  auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
  // zero-valued identity elements for the reduction
  PetscReal n = 0, na = 0, nr = 0;
  PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

  PetscFunctionBegin;
  // clang-format off
  if (wnormtype == NORM_INFINITY) {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_INFINITY>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_INFINITY>{}
      )
    );
  } else {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_2>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_2>{}
      )
    );
  }
  // clang-format on
  // the 2-norm accumulators hold sums of squares; take the square roots here
  if (wnormtype == NORM_2) {
    *norm  = PetscSqrtReal(*norm);
    *norma = PetscSqrtReal(*norma);
    *normr = PetscSqrtReal(*normr);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1481 
1482 } // namespace detail
1483 
// v->ops->errorwnorm
//
// Computes the weighted error norms used by the TS adaptors: for each entry the error
// (either E[i] or U[i]-Y[i]) is scaled by atol + rtol*max(|U[i]|, |Y[i]|), where the
// tolerances may be scalars or per-entry vectors. Outputs the combined norm plus the
// absolute-only and relative-only variants, each with the count of contributing entries.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  const auto         nl  = U->map->n;
  // scalar tolerances are lifted to constant iterators so each branch below can feed the
  // same tuple-of-iterators machinery whether or not vector tolerances were supplied
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // E, vatol, and vrtol are optional; map a NULL Vec to a NULL device pointer so the
    // pointer variables below can always be initialized (unused ones are never dereferenced)
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    // Explicit dispatch over (vatol scalar/vector) x (vrtol scalar/vector) x (E given or
    // not). Each case must be spelled out separately because the tuple-of-iterator types
    // differ at compile time (constant iterator vs device pointer per tolerance, and the
    // E-transform takes one extra tuple member); note constant iterators need no "+ nl"
    if (!vatol && !vrtol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vatol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vrtol) {
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    } else {
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1605 
1606 namespace detail
1607 {
1608 struct dotnorm2_mult {
1609   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1610   {
1611     const auto conjt = PetscConj(t);
1612 
1613     return {s * conjt, t * conjt};
1614   }
1615 };
1616 
1617 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1618 // would do it myself but now I am worried that they do so on purpose...
1619 struct dotnorm2_tuple_plus {
1620   using value_type = thrust::tuple<PetscScalar, PetscScalar>;
1621 
1622   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
1623 };
1624 
1625 } // namespace detail
1626 
// v->ops->dotnorm2
//
// Computes, in a single pass over both arrays, *dp = sum s_i * conj(t_i) and
// *nm = sum t_i * conj(t_i) using thrust::inner_product with the tuple-valued
// functors defined above.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // identity elements for the tuple reduction
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1654 
namespace detail
{
// elementwise complex-conjugation functor for PointwiseUnary_ (identity on real builds)
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
};

} // namespace detail
1662 
1663 // v->ops->conjugate
1664 template <device::cupm::DeviceType T>
1665 inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
1666 {
1667   PetscFunctionBegin;
1668   if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
1669   PetscFunctionReturn(PETSC_SUCCESS);
1670 }
1671 
1672 namespace detail
1673 {
1674 struct real_part {
1675   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }
1676 
1677   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
1678 };
1679 
1680 // deriving from Operator allows us to "store" an instance of the operator in the class but
1681 // also take advantage of empty base class optimization if the operator is stateless
1682 template <typename Operator>
1683 class tuple_compare : Operator {
1684 public:
1685   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1686   using operator_type = Operator;
1687 
1688   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1689   {
1690     if (op_()(y.get<0>(), x.get<0>())) {
1691       // if y is strictly greater/less than x, return y
1692       return y;
1693     } else if (y.get<0>() == x.get<0>()) {
1694       // if equal, prefer lower index
1695       return y.get<1>() < x.get<1>() ? y : x;
1696     }
1697     // otherwise return x
1698     return x;
1699   }
1700 
1701 private:
1702   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1703 };
1704 
1705 } // namespace detail
1706 
// Shared implementation of Max()/Min(): reduces over the (real parts of the) entries of
// v, optionally tracking the index of the winning entry in *p. The caller seeds *m with
// the identity element of the reduction (PETSC_MIN_REAL for max, PETSC_MAX_REAL for min).
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  // location output is optional; -1 signals "no location" (e.g. empty vector)
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // index requested: reduce over (value, index) pairs built via a counting iterator
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value only: plain reduction seeded with the caller-provided identity in *m
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}
1763 
1764 // v->ops->max
1765 template <device::cupm::DeviceType T>
1766 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
1767 {
1768   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1769   using unary_functor = thrust::maximum<PetscReal>;
1770 
1771   PetscFunctionBegin;
1772   *m = PETSC_MIN_REAL;
1773   // use {} constructor syntax otherwise most vexing parse
1774   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1775   PetscFunctionReturn(PETSC_SUCCESS);
1776 }
1777 
1778 // v->ops->min
1779 template <device::cupm::DeviceType T>
1780 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
1781 {
1782   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1783   using unary_functor = thrust::minimum<PetscReal>;
1784 
1785   PetscFunctionBegin;
1786   *m = PETSC_MAX_REAL;
1787   // use {} constructor syntax otherwise most vexing parse
1788   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1789   PetscFunctionReturn(PETSC_SUCCESS);
1790 }
1791 
1792 // v->ops->sum
1793 template <device::cupm::DeviceType T>
1794 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
1795 {
1796   PetscFunctionBegin;
1797   if (const auto n = v->map->n) {
1798     PetscDeviceContext dctx;
1799     cupmStream_t       stream;
1800 
1801     PetscCall(GetHandles_(&dctx, &stream));
1802     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1803     // REVIEW ME: why not cupmBlasXasum()?
1804     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1805     // REVIEW ME: must be at least n additions
1806     PetscCall(PetscLogGpuFlops(n));
1807   } else {
1808     *sum = 0.0;
1809   }
1810   PetscFunctionReturn(PETSC_SUCCESS);
1811 }
1812 
1813 template <device::cupm::DeviceType T>
1814 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
1815 {
1816   PetscFunctionBegin;
1817   PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v));
1818   PetscFunctionReturn(PETSC_SUCCESS);
1819 }
1820 
// v->ops->setrandom
//
// Fills v with values from the given random generator. A curand-backed generator can
// write to device memory directly; any other generator produces values on the host.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    // fill device memory directly for curand; otherwise fill the host array and let the
    // usual offload machinery migrate it on the next device access
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(PETSC_SUCCESS);
}
1840 
// v->ops->setpreallocation
//
// Preallocates the COO assembly structures for ncoo entries at indices coo_i[]: builds
// the host-side maps first, then mirrors them onto the device via the base class.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1853 
namespace kernels
{
// Accumulation core shared by the COO kernels: for each output entry i, sum the
// (possibly repeated) source values vv[perm[k]] for k in [jmap[i], jmap[i+1]) and either
// overwrite (INSERT_VALUES) or add into xv at the destination index given by xvindex(i).
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    // jmap[i]..jmap[i+1] delimits the run of permuted source values mapping to entry i
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}

namespace
{

// kernel entry point used by SetValuesCOO(): destination index is simply i (identity map)
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}

} // namespace

  #if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}

} // namespace do_not_use
  #endif

} // namespace kernels
1912 
// v->ops->setvaluescoo
//
// Assembles the preallocated COO entries: the values v[] (host or device memory) are
// scattered/accumulated into x using the maps built by SetPreallocationCOO().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry so write-only access suffices; ADD_VALUES must
    // also read the existing entries
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // free the temporary device copy (if one was made) and synchronize so the caller may
  // safely reuse or free v[] upon return
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1945 
1946 } // namespace impl
1947 
1948 // ==========================================================================================
1949 // VecSeq_CUPM - Implementations
1950 // ==========================================================================================
1951 
1952 namespace
1953 {
1954 
1955 template <device::cupm::DeviceType T>
1956 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
1957 {
1958   PetscFunctionBegin;
1959   PetscValidPointer(v, 4);
1960   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
1961   PetscFunctionReturn(PETSC_SUCCESS);
1962 }
1963 
// Async analogue of VecCreateSeqCUPMWithArrays(): wraps caller-provided host and/or
// device arrays (either may be NULL) in a sequential device vector without copying.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  // cpuarray is optional; only validate it when a nonzero-length host array was supplied
  if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4);
  PetscValidPointer(v, 6);
  PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1973 
// Common implementation behind all VecCUPMGetArray*Async() variants; the memory-access
// mode is fixed at compile time. A NULL dctx is replaced by the current default context.
template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
  PetscValidPointer(a, 2);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1984 
// Common implementation behind all VecCUPMRestoreArray*Async() variants; mirrors the Get
// counterpart above. A NULL dctx is replaced by the current default context.
template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1994 
// Get the device array of v for reading and writing (no copy-validity guarantees waived)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2002 
// Restore a device array obtained with VecCUPMGetArrayAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2010 
// Get the device array of v for read-only access (const_cast is safe: READ mode never writes)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2018 
// Restore a device array obtained with VecCUPMGetArrayReadAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2026 
// Get the device array of v for write-only access (previous contents need not be valid)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2034 
// Restore a device array obtained with VecCUPMGetArrayWriteAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2042 
// Temporarily replace the device array of vin with a[] (undo with VecCUPMResetArrayAsync())
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2051 
// Permanently replace the device array of vin with a[] (the old array is released)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2060 
// Undo a prior VecCUPMPlaceArrayAsync(), restoring vin's original device array
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2069 
2070 } // anonymous namespace
2071 
2072 } // namespace cupm
2073 
2074 } // namespace vec
2075 
2076 } // namespace Petsc
2077 
2078 #endif // __cplusplus
2079 
2080 #endif // PETSCVECSEQCUPM_HPP
2081