xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision a43132046c00290928454b92120f2306b39b67d0)
1 #ifndef PETSCVECSEQCUPM_HPP
2 #define PETSCVECSEQCUPM_HPP
3 
4 #include <petsc/private/veccupmimpl.h>
5 
6 #include <petsc/private/randomimpl.h> // for _p_PetscRandom
7 
8 #include <petsc/private/cpp/utility.hpp> // util::exchange, util::index_sequence
9 
10 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
11 #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
12 
13 #if PetscDefined(USE_COMPLEX)
14   #include <thrust/transform_reduce.h>
15 #endif
16 #include <thrust/transform.h>
17 #include <thrust/reduce.h>
18 #include <thrust/functional.h>
19 #include <thrust/tuple.h>
20 #include <thrust/device_ptr.h>
21 #include <thrust/iterator/zip_iterator.h>
22 #include <thrust/iterator/counting_iterator.h>
23 #include <thrust/iterator/constant_iterator.h>
24 #include <thrust/inner_product.h>
25 
26 namespace Petsc
27 {
28 
29 namespace vec
30 {
31 
32 namespace cupm
33 {
34 
35 namespace impl
36 {
37 
38 // ==========================================================================================
39 // VecSeq_CUPM
40 // ==========================================================================================
41 
// Implementation class for the sequential (single MPI rank) CUPM vector. The template
// parameter T selects the device backend (CUDA or HIP). Vec_CUPMBase supplies the shared
// host/device array management via CRTP; this class layers the VecSeq-specific operations
// (BLAS calls, thrust transforms, and hand-rolled CUPM kernels) on top of it.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // cast v->data to the host-side Vec_Seq implementation struct
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  // vector type name of the device flavor (e.g. VECSEQCUDA)
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  // vector type name of the host flavor (VECSEQ)
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  // thin forwarders to the host VecSeq implementation, consumed by the base class
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bump the object state of a locally-empty (but globally nonempty) vector; see the
  // comment above its definition for the deadlock scenario this prevents
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers, tag-dispatched on whether PetscScalar is complex
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers installed in v->ops
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};
121 
122 // ==========================================================================================
123 // VecSeq_CUPM - Private API
124 // ==========================================================================================
125 
// Return the host-side Vec_Seq struct stashed in v->data. No type checking is performed
// here; callers are expected to have already verified the vector's type.
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}
131 
// Type name of the device flavor of this vector (resolves per-backend, e.g. VECSEQCUDA
// or VECSEQHIP).
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}
137 
// Type name of the host (CPU) flavor of this vector.
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}
143 
// Forward destruction of the host-side representation to the plain VecSeq implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}
149 
// Forward host-array reset to the plain VecSeq implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}
155 
// Forward host-array placement to the plain VecSeq implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}
161 
// Create the host-side VecSeq representation of v, optionally wrapping a caller-provided
// host array. A sequential vector only makes sense on a communicator of size 1, which is
// enforced here. alloc_missing is always set to PETSC_FALSE (when provided) since
// VecCreate_Seq_Private() performs any needed host allocation itself.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
174 
// for functions with an early return based on vec size we still need to artificially bump the
176 // object state. This is to prevent the following:
177 //
178 // 0. Suppose you have a Vec {
179 //   rank 0: [0],
180 //   rank 1: [<empty>]
181 // }
182 // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
183 // 2. Vec enters e.g. VecSet(10)
184 // 3. rank 1 has local size 0 and bails immediately
185 // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
186 // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
187 // 6. Vec enters VecNorm(), and calls VecNormAvailable()
188 // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
189 // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
190 // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
// Artificially bump the object state of a vector whose local size is 0 but whose global
// size is nonzero, keeping its state in lockstep with ranks that did real work (see the
// deadlock scenario described in the comment above).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
198 
// Common creation core: build the host-side VecSeq representation then initialize the
// CUPM base with the (optional) user-provided host and/or device arrays. Either array
// pointer may be nullptr, in which case the corresponding storage is allocated lazily.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
207 
// Apply a pointwise binary functor on the device: zout[i] = binary(xin[i], yin[i]).
// xin and yin are read, zout is written; the local size is taken from zout. Empty local
// vectors skip the kernel but still get their object state bumped (see
// MaybeIncrementEmptyLocalVec()).
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
239 
// Apply a pointwise unary functor on the device. With yin == nullptr (or yin == xinout)
// the operation is in-place: xinout[i] = unary(xinout[i]); otherwise it is out-of-place:
// yin[i] = unary(xinout[i]). NOTE(review): despite the parameter names, xinout is the
// input and yin is the output in the out-of-place case. Empty local vectors skip the
// kernel but still bump the object state of whichever vector would have been written.
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;
    // shared launch helper; the output iterator falls back to the input pointer when
    // operating in-place
    const auto         apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(GetHandles_(&dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogFlops(n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
285 
286 // ==========================================================================================
287 // VecSeq_CUPM - Public API - Constructors
288 // ==========================================================================================
289 
290 // VecCreateSeqCUPM()
// Create a sequential CUPM vector of local (== global) size n with block size bs.
// call_set_type controls whether VecSetType() is invoked (callers that install their own
// arrays pass PETSC_FALSE to avoid the default allocation path).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
298 
299 // VecCreateSeqCUPMWithArrays()
// Create a sequential CUPM vector wrapping caller-provided host and/or device arrays
// (either may be NULL). The arrays are adopted without copying; ownership stays with the
// caller.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
313 
314 // v->ops->duplicate
// v->ops->duplicate: create a new vector y with the same layout and type as v (values are
// not copied).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
325 
326 // ==========================================================================================
327 // VecSeq_CUPM - Public API - Utility
328 // ==========================================================================================
329 
330 // v->ops->bindtocpu
// v->ops->bindtocpu: toggle between host and device execution. After the base class has
// swapped the common function pointers, (re)install the VecSeq-specific ops; the
// VecSetOp_CUPM() macro picks the host (VecXXX_Seq) or device (this class) variant based
// on usehost.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device implementation, so the host version is used unconditionally
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, Conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  VecSetOp_CUPM(errorwnorm, nullptr, ErrorWnorm);
  PetscFunctionReturn(PETSC_SUCCESS);
}
356 
357 // ==========================================================================================
358 // VecSeq_CUPM - Public API - Mutators
359 // ==========================================================================================
360 
361 // v->ops->getlocalvector or v->ops->getlocalvectorread
// v->ops->getlocalvector / getlocalvectorread: make w an alias for the local part of v.
// Fast path: when v is a native PETSc vector and w is a seqcupm vector, w's own storage
// (host array, device array, spptr) is freed and v's representation is handed to w
// directly, avoiding any copy. Slow path: w borrows v's host array via
// VecGetArray(Read)(). The access template parameter distinguishes read-only from
// read-write borrowing.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release w's own storage before it takes over (or borrows) v's
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // free with the matching (possibly pinned) allocator; exchange clears the
        // pinned flag so w does not later try to pin-free borrowed memory
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // steal v's entire representation; RestoreLocalVector() gives it back
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // borrow only the host array; device storage (if needed) is allocated separately
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
416 
417 // v->ops->restorelocalvector or v->ops->restorelocalvectorread
// v->ops->restorelocalvector / restorelocalvectorread: undo GetLocalVector(). On the fast
// path v takes its representation back from w; on the slow path the borrowed host array
// is returned via VecRestoreArray(Read)() and any device scratch w allocated is freed.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // drop the device array w allocated in DeviceAllocateCheck_(); its contents are
    // stale now that the host array has been handed back
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
452 
453 // ==========================================================================================
454 // VecSeq_CUPM - Public API - Compute Methods
455 // ==========================================================================================
456 
457 // v->ops->aypx
// v->ops->aypx: y = alpha*y + x. Three cases:
//   alpha == 0: plain device-to-device copy of x into y
//   alpha == 1: single BLAS axpy (y += x)
//   otherwise : BLAS scal (y *= alpha) followed by axpy with coefficient 1
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
  const auto         sync = n != 0; // only synchronize when device work was queued
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  if (alpha == PetscScalar(0.0)) {
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      // scope the array accessors so their destructors run before the final sync
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
500 
501 // v->ops->axpy
// v->ops->axpy: y += alpha*x via BLAS axpy. Falls back to the host implementation when x
// is not a CUPM vector (its data may only exist on the host).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
525 
526 // v->ops->pointwisedivide
// v->ops->pointwisedivide: w[i] = x[i] / y[i]. Runs on the host when either operand is
// bound to the CPU, otherwise dispatches a thrust transform on the device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec win, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, win is written!
    PetscCall(PointwiseBinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
539 
540 // v->ops->pointwisemult
// v->ops->pointwisemult: w[i] = x[i] * y[i]. Runs on the host when either operand is
// bound to the CPU, otherwise dispatches a thrust transform on the device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec win, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall(VecPointwiseMult_Seq(win, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, win is written!
    PetscCall(PointwiseBinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
553 
554 namespace detail
555 {
556 
557 struct reciprocal {
558   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
559   {
560     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
561     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
562     // everything in PetscScalar...
563     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
564   }
565 };
566 
567 } // namespace detail
568 
569 // v->ops->reciprocal
// v->ops->reciprocal: x[i] = 1/x[i] in-place, with zero entries left untouched (see
// detail::reciprocal).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
577 
578 // v->ops->waxpy
// v->ops->waxpy: w = alpha*x + y. alpha == 0 degenerates to a copy of y into w;
// otherwise y is copied into w and a BLAS axpy adds alpha*x on top.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    PetscCall(Copy(yin, win));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
604 
605 namespace kernels
606 {
607 
// Device kernel for MAXPY: xptr[i] += sum_j aptr[j] * yptr_j[i], for a compile-time pack
// of N y-vectors. The N coefficients are staged into shared memory once per block before
// the strided loop begins.
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
  // (floating-point addition is not associative: accumulating into a zeroed partial sum
  // vs. seeding the sum with xptr[i] changes the rounding order)
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
#else
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}
639 
640 } // namespace kernels
641 
642 namespace detail
643 {
644 
645 // a helper-struct to gobble the size_t input, it is used with template parameter pack
646 // expansion such that
647 // typename repeat_type<MyType, IdxParamPack>...
648 // expands to
649 // MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
// The std::size_t parameter exists only to be expanded over; it does not affect the
// resulting type.
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};
654 
655 } // namespace detail
656 
// Launch one MAXPY_kernel instantiation handling sizeof...(Idx) y-vectors; the index
// sequence both sets the kernel's template arity (via repeat_type) and expands the packed
// device pointers for yin[0..sizeof...(Idx)).
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
673 
// Process the next N y-vectors starting at offset yidx, then advance yidx by N so the
// caller's dispatch loop can continue with the remainder.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
683 
684 // v->ops->maxpy
// v->ops->maxpy: x += sum_i alpha[i]*y[i]. The coefficients are copied to the device
// once, then the y-vectors are consumed in batches of up to 8 per kernel launch (the
// kernel arity is fixed at compile time, hence the switch over the remaining count).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  const auto         n = xin->map->n;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
740 
// v->ops->dot: z = conj(x) . y (PETSc conjugates the second BLAS argument by convention)
// via BLAS dot. An empty vector short-circuits to z = 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
761 
762 #define MDOT_WORKGROUP_NUM  128
763 #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
764 
765 namespace kernels
766 {
767 
// Number of vector entries each block (work group) should process: ceil(size / gridDim.x),
// clamped to at least 1 so tiny vectors still make progress.
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}
774 
// First stage of the batched multi-dot: computes per-block partial sums of x . y for each
// of the N vectors in y.... Block bx writes its partial for vector i to
// results[bx + i * gridDim.x]; a follow-up kernel (sum_kernel) completes the reduction
// across blocks. Note the product is ylocal[j][i] * xi with no conjugation -- this kernel
// is used on the real-scalar code path.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  // number of y vectors handled by this instantiation
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N]; // per-thread running sums, one slot per y vector

  // shared scratch: one MDOT_WORKGROUP_SIZE-wide slab per y vector
  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

#pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  // each thread strides through this block's [begin, end) slice of x
  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

#pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stage per-thread sums in shared memory for the block-level tree reduction below
#pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
#pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
823 
824 namespace
825 {
826 
// Final reduction stage for the batched multi-dot: on entry results holds
// MDOT_WORKGROUP_NUM partial sums per dot product (laid out contiguously per vector); on
// exit results[i] holds the fully reduced value of dot product i.
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  // NOTE(review): fixed capacity of 8 -- assumes each thread's grid-stride loop below
  // visits at most 8 indices (i.e. the launch provides enough threads); confirm against
  // the PetscCUPMLaunchKernel1D(nvbatch, ...) launch in MDot_
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // pop entries in reverse of the order they were accumulated above
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}
863 
864 } // namespace
865 
866 } // namespace kernels
867 
// Launches one MDot_kernel instantiation covering sizeof...(Idx) y vectors. The
// index_sequence expands both the kernel's template parameter pack (one const
// PetscScalar* per vector) and the corresponding DeviceArrayRead arguments.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
886 
// Batch wrapper: dispatches a kernel handling the N vectors starting at yin[yidx], with
// its workgroup partials written to results + yidx * MDOT_WORKGROUP_NUM, then advances
// yidx past the batch.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
896 
// Real-scalar (std::false_type <=> !PETSC_USE_COMPLEX) implementation of v->ops->mdot.
// Processes the y vectors in batches of up to 8 with the custom MDot_kernel -- except a
// lone trailing singleton, which goes through cupmBlasXdot instead -- optionally
// round-robining batches over forked sub-streams, then reduces the per-workgroup
// partials with sum_kernel and copies the final values back to the host array z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // rotate through the forked sub-contexts (if any) so batches can overlap
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a lone remaining vector is cheaper to do directly via BLAS, writing straight
        // into z rather than into the workgroup scratchpad
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // reduce the workgroup partials down to one scalar per dot product
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}
986 
987 #undef MDOT_WORKGROUP_NUM
988 #undef MDOT_WORKGROUP_SIZE
989 
// Complex-scalar (std::true_type <=> PETSC_USE_COMPLEX) implementation of v->ops->mdot:
// one (conjugating) cupmBlasXdot call per y vector, with each result written directly to
// device memory (BLAS device pointer mode) and the calls spread round-robin over up to 8
// forked sub-contexts; results are copied back to z in one async transfer at the end.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // switch the handle to device pointer mode (if needed) so the dot result lands in
    // d_z on the device, then restore the caller's mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}
1024 
// v->ops->mdot
// Computes z[i] = y[i] . x for nv vectors. nv == 1 defers to Dot (which also handles
// nv == 0); an empty local vector just zeroes z; otherwise dispatch to the complex or
// real MDot_ implementation at compile time based on PETSC_USE_COMPLEX.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1047 
1048 // v->ops->set
1049 template <device::cupm::DeviceType T>
1050 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1051 {
1052   const auto         n = xin->map->n;
1053   PetscDeviceContext dctx;
1054   cupmStream_t       stream;
1055 
1056   PetscFunctionBegin;
1057   PetscCall(GetHandles_(&dctx, &stream));
1058   {
1059     const auto xptr = DeviceArrayWrite(dctx, xin);
1060 
1061     if (alpha == PetscScalar(0.0)) {
1062       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1063     } else {
1064       const auto dptr = thrust::device_pointer_cast(xptr.data());
1065 
1066       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1067     }
1068     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1069   }
1070   PetscFunctionReturn(PETSC_SUCCESS);
1071 }
1072 
1073 // v->ops->scale
1074 template <device::cupm::DeviceType T>
1075 inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1076 {
1077   PetscFunctionBegin;
1078   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1079   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1080     PetscCall(Set(xin, alpha));
1081   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1082     PetscDeviceContext dctx;
1083     cupmBlasHandle_t   cupmBlasHandle;
1084 
1085     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1086     PetscCall(PetscLogGpuTimeBegin());
1087     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1088     PetscCall(PetscLogGpuTimeEnd());
1089     PetscCall(PetscLogGpuFlops(n));
1090     PetscCall(PetscDeviceContextSynchronize(dctx));
1091   } else {
1092     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1093   }
1094   PetscFunctionReturn(PETSC_SUCCESS);
1095 }
1096 
1097 // v->ops->tdot
1098 template <device::cupm::DeviceType T>
1099 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1100 {
1101   PetscFunctionBegin;
1102   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1103     PetscDeviceContext dctx;
1104     cupmBlasHandle_t   cupmBlasHandle;
1105 
1106     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1107     PetscCall(PetscLogGpuTimeBegin());
1108     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1109     PetscCall(PetscLogGpuTimeEnd());
1110     PetscCall(PetscLogGpuFlops(2 * n - 1));
1111   } else {
1112     *z = 0.0;
1113   }
1114   PetscFunctionReturn(PETSC_SUCCESS);
1115 }
1116 
// v->ops->copy
// Copies x into y, choosing the memcpy direction from the two vectors' offload masks so
// that no unnecessary host<->device migration is triggered.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      // an unallocated CUPM vector may be written on whichever side x currently lives on
      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU:
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      // destination lives on the device: write through the device accessor
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // destination lives on the host: go through the generic VecGetArrayWrite path
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1186 
1187 // v->ops->swap
1188 template <device::cupm::DeviceType T>
1189 inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1190 {
1191   PetscFunctionBegin;
1192   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1193   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1194     PetscDeviceContext dctx;
1195     cupmBlasHandle_t   cupmBlasHandle;
1196 
1197     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1198     PetscCall(PetscLogGpuTimeBegin());
1199     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1200     PetscCall(PetscLogGpuTimeEnd());
1201     PetscCall(PetscDeviceContextSynchronize(dctx));
1202   } else {
1203     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1204     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1205   }
1206   PetscFunctionReturn(PETSC_SUCCESS);
1207 }
1208 
// v->ops->axpby
// Computes y = alpha*x + beta*y. Degenerate coefficient values are routed to cheaper
// operations (Scale / AXPY / AYPX); otherwise beta == 0 becomes memcpy+scal and the
// general case scal+axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  if (alpha == PetscScalar(0.0)) {
    // y = beta * y
    PetscCall(Scale(yin, beta));
  } else if (beta == PetscScalar(1.0)) {
    // y = alpha * x + y
    PetscCall(AXPY(yin, alpha, xin));
  } else if (alpha == PetscScalar(1.0)) {
    // y = x + beta * y
    PetscCall(AYPX(yin, beta, xin));
  } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
    const auto         betaIsZero = beta == PetscScalar(0.0);
    const auto         aptr       = cupmScalarPtrCast(&alpha);
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    {
      const auto xptr = DeviceArrayRead(dctx, xin);

      if (betaIsZero /* beta = 0 */) {
        // here we can get away with purely write-only as we memcpy into it first
        const auto   yptr = DeviceArrayWrite(dctx, yin);
        cupmStream_t stream;

        PetscCall(GetHandlesFrom_(dctx, &stream));
        PetscCall(PetscLogGpuTimeBegin());
        // y = alpha * x, realized as copy-then-scale
        PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
      } else {
        const auto yptr = DeviceArrayReadWrite(dctx, yin);

        PetscCall(PetscLogGpuTimeBegin());
        // y = beta * y, then y += alpha * x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1255 
// v->ops->axpbypcz
// Computes z = alpha*x + beta*y + gamma*z as a scale followed by two AXPYs. The
// gamma == 1 guard merely skips a call that Scale would short-circuit anyway.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  if (gamma != PetscScalar(1.0)) PetscCall(Scale(zin, gamma));
  PetscCall(AXPY(zin, alpha, xin));
  PetscCall(AXPY(zin, beta, yin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1266 
// v->ops->norm
// Computes the requested norm(s) of x via BLAS. For NORM_1_AND_2 the switch deliberately
// falls through from the 1-norm case (with z advanced one slot) into the 2-norm case so
// both values are produced in one pass through the switch.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // BLAS amax indices are 1-based, hence the - 1 when fetching the entry
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: zero the first slot, and (branchlessly) also the second slot when
    // both norms were requested
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1316 
1317 namespace detail
1318 {
1319 
// Shared base for the weighted-error-norm transform functors used by ErrorWnorm. Holds
// the ignore threshold and provides the per-entry norm contribution helper. The
// result_type tuple packs (norm, norma, normr, norm_loc, norma_loc, normr_loc).
template <NormType wnormtype>
class ErrorWNormTransformBase {
public:
  using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

  constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

protected:
  // one entry's contribution: its norm term plus a 0/1 flag marking whether it counted
  struct NormTuple {
    PetscReal norm;
    PetscInt  loc;
  };

  // Maps an error and its tolerance to a norm contribution: err/tol for the infinity
  // norm, (err/tol)^2 for the 2-norm; a non-positive tolerance contributes nothing.
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
  {
    if (tol > 0.) {
      const auto val = err / tol;

      return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
    } else {
      return {0.0, 0};
    }
  }

  // entries whose magnitude falls below this threshold are skipped entirely
  PetscReal ignore_max_;
};
1346 
1347 template <NormType wnormtype>
1348 struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
1349   using base_type     = ErrorWNormTransformBase<wnormtype>;
1350   using result_type   = typename base_type::result_type;
1351   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;
1352 
1353   using base_type::base_type;
1354 
1355   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1356   {
1357     const auto u     = x.get<0>();
1358     const auto y     = x.get<1>();
1359     const auto au    = PetscAbsScalar(u);
1360     const auto ay    = PetscAbsScalar(y);
1361     const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
1362     const auto tola  = skip ? 0.0 : PetscRealPart(x.get<2>());
1363     const auto tolr  = skip ? 0.0 : PetscRealPart(x.get<3>()) * PetscMax(au, ay);
1364     const auto tol   = tola + tolr;
1365     const auto err   = PetscAbsScalar(u - y);
1366     const auto tup_a = this->compute_norm_(err, tola);
1367     const auto tup_r = this->compute_norm_(err, tolr);
1368     const auto tup_n = this->compute_norm_(err, tol);
1369 
1370     return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1371   }
1372 };
1373 
1374 template <NormType wnormtype>
1375 struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
1376   using base_type     = ErrorWNormTransformBase<wnormtype>;
1377   using result_type   = typename base_type::result_type;
1378   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;
1379 
1380   using base_type::base_type;
1381 
1382   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1383   {
1384     const auto au    = PetscAbsScalar(x.get<0>());
1385     const auto ay    = PetscAbsScalar(x.get<1>());
1386     const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
1387     const auto tola  = skip ? 0.0 : PetscRealPart(x.get<3>());
1388     const auto tolr  = skip ? 0.0 : PetscRealPart(x.get<4>()) * PetscMax(au, ay);
1389     const auto tol   = tola + tolr;
1390     const auto err   = PetscAbsScalar(x.get<2>());
1391     const auto tup_a = this->compute_norm_(err, tola);
1392     const auto tup_r = this->compute_norm_(err, tolr);
1393     const auto tup_n = this->compute_norm_(err, tol);
1394 
1395     return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1396   }
1397 };
1398 
// Binary reduction combining two per-entry result tuples: the three norm components are
// maxed for the infinity norm and summed for the 2-norm, while the three counts are
// always summed. wnormtype is a compile-time constant so the branch folds away.
template <NormType wnormtype>
struct ErrorWNormReduce {
  using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
  {
    // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
    // result_type is a template, so in order to fix this we would need to write:
    //
    // lhs.template get<0>()
    //
    // which is unseemly.
    if (wnormtype == NORM_INFINITY) {
      // clang-format off
      return {
        PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
        PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
        PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    } else {
      // clang-format off
      return {
        thrust::get<0>(lhs) + thrust::get<0>(rhs),
        thrust::get<1>(lhs) + thrust::get<1>(rhs),
        thrust::get<2>(lhs) + thrust::get<2>(rhs),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    }
  }
};
1436 
// Runs the weighted-norm transform_reduce over the zipped input ranges [first, last),
// instantiating the transform/reduce functors for the runtime wnormtype (only
// NORM_INFINITY and NORM_2 are distinguished by the functors), then applies the final
// square root to the 2-norm results. All six outputs are written unconditionally.
template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
  auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
  PetscReal n = 0, na = 0, nr = 0;   // zero-valued initial norms for the reduction
  PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0; // zero-valued initial counts

  PetscFunctionBegin;
  // clang-format off
  if (wnormtype == NORM_INFINITY) {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_INFINITY>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_INFINITY>{}
      )
    );
  } else {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_2>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_2>{}
      )
    );
  }
  // clang-format on
  // the reduction accumulated squares for the 2-norm; take the square roots here
  if (wnormtype == NORM_2) {
    *norm  = PetscSqrtReal(*norm);
    *norma = PetscSqrtReal(*norma);
    *normr = PetscSqrtReal(*normr);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1480 
1481 } // namespace detail
1482 
1483 // v->ops->errorwnorm
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  // Compute the weighted error norms of U vs Y (and, when given, the error estimate E) in
  // a single fused device pass via detail::ExecuteWNorm(). Each of atol/rtol may be
  // supplied either as a scalar or as a per-entry vector (vatol/vrtol); absent vectors
  // are replaced by constant iterators over the scalar value so that all four
  // combinations share the same driver. norm/norma/normr receive the combined, absolute,
  // and relative results; the *_loc outputs are the PetscInt values reduced alongside
  // them (presumably contribution counts -- see detail::ExecuteWNorm()).
  const auto         nl  = U->map->n;
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // device pointer to v's data, or a null device pointer when the optional Vec is absent
    // NOTE(review): the DeviceArrayRead temporary created here is destroyed when the
    // lambda returns while the raw pointer is used below; presumably restoring a
    // read-only array leaves the device allocation valid -- confirm against
    // DeviceArrayRead's contract
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    // uarr/yarr are kept alive as named RAII handles for the duration of this scope
    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    // dispatch on which tolerances are vectors; the tuple layout handed to ExecuteWNorm
    // must match what ErrorWNormTransform/ErrorWNormETransform expect
    if (!vatol && !vrtol) {
      // scalar atol, scalar rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vatol) {
      // scalar atol, per-entry rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vrtol) {
      // per-entry atol, scalar rtol
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    } else {
      // per-entry atol, per-entry rtol
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1604 
1605 namespace detail
1606 {
1607 struct dotnorm2_mult {
1608   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1609   {
1610     const auto conjt = PetscConj(t);
1611 
1612     return {s * conjt, t * conjt};
1613   }
1614 };
1615 
1616 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1617 // would do it myself but now I am worried that they do so on purpose...
1618 struct dotnorm2_tuple_plus {
1619   using value_type = thrust::tuple<PetscScalar, PetscScalar>;
1620 
1621   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
1622 };
1623 
1624 } // namespace detail
1625 
1626 // v->ops->dotnorm2
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  // Compute dp = sum_i s_i * conj(t_i) and nm = sum_i t_i * conj(t_i) in one fused
  // device pass: dotnorm2_mult produces both elementwise terms and dotnorm2_tuple_plus
  // reduces the resulting tuples component-wise.
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // (0, 0) initial values for the tuple reduction
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1653 
1654 namespace detail
1655 {
1656 struct conjugate {
1657   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
1658 };
1659 
1660 } // namespace detail
1661 
1662 // v->ops->conjugate
1663 template <device::cupm::DeviceType T>
1664 inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
1665 {
1666   PetscFunctionBegin;
1667   if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin));
1668   PetscFunctionReturn(PETSC_SUCCESS);
1669 }
1670 
1671 namespace detail
1672 {
1673 struct real_part {
1674   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }
1675 
1676   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
1677 };
1678 
1679 // deriving from Operator allows us to "store" an instance of the operator in the class but
1680 // also take advantage of empty base class optimization if the operator is stateless
1681 template <typename Operator>
1682 class tuple_compare : Operator {
1683 public:
1684   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1685   using operator_type = Operator;
1686 
1687   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1688   {
1689     if (op_()(y.get<0>(), x.get<0>())) {
1690       // if y is strictly greater/less than x, return y
1691       return y;
1692     } else if (y.get<0>() == x.get<0>()) {
1693       // if equal, prefer lower index
1694       return y.get<1>() < x.get<1>() ? y : x;
1695     }
1696     // otherwise return x
1697     return x;
1698   }
1699 
1700 private:
1701   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1702 };
1703 
1704 } // namespace detail
1705 
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  // Common driver for Max() and Min(). The caller pre-seeds *m with the reduction's
  // identity value (PETSC_MIN_REAL resp. PETSC_MAX_REAL); tuple_ftr reduces
  // (value, index) pairs when the location p is requested, unary_ftr reduces bare
  // values otherwise.
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  // default location result; overwritten below only if the vector has local entries
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
#if PetscDefined(USE_COMPLEX)
  #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
  #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // reduce over (value, index) pairs so the location of the extremum is found too
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
#undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}
1762 
1763 // v->ops->max
1764 template <device::cupm::DeviceType T>
1765 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
1766 {
1767   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1768   using unary_functor = thrust::maximum<PetscReal>;
1769 
1770   PetscFunctionBegin;
1771   *m = PETSC_MIN_REAL;
1772   // use {} constructor syntax otherwise most vexing parse
1773   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1774   PetscFunctionReturn(PETSC_SUCCESS);
1775 }
1776 
1777 // v->ops->min
1778 template <device::cupm::DeviceType T>
1779 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
1780 {
1781   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1782   using unary_functor = thrust::minimum<PetscReal>;
1783 
1784   PetscFunctionBegin;
1785   *m = PETSC_MAX_REAL;
1786   // use {} constructor syntax otherwise most vexing parse
1787   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
1788   PetscFunctionReturn(PETSC_SUCCESS);
1789 }
1790 
1791 // v->ops->sum
1792 template <device::cupm::DeviceType T>
1793 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
1794 {
1795   PetscFunctionBegin;
1796   if (const auto n = v->map->n) {
1797     PetscDeviceContext dctx;
1798     cupmStream_t       stream;
1799 
1800     PetscCall(GetHandles_(&dctx, &stream));
1801     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1802     // REVIEW ME: why not cupmBlasXasum()?
1803     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1804     // REVIEW ME: must be at least n additions
1805     PetscCall(PetscLogGpuFlops(n));
1806   } else {
1807     *sum = 0.0;
1808   }
1809   PetscFunctionReturn(PETSC_SUCCESS);
1810 }
1811 
1812 template <device::cupm::DeviceType T>
1813 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
1814 {
1815   PetscFunctionBegin;
1816   PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v));
1817   PetscFunctionReturn(PETSC_SUCCESS);
1818 }
1819 
// v->ops->setrandom
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
{
  // Fill v with random values. When the generator is CURAND-backed the values are
  // produced directly into the device array; otherwise they are generated into the host
  // array (and migrated on next device access).
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    // no local entries: still bump the object state (see MaybeIncrementEmptyLocalVec)
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(PETSC_SUCCESS);
}
1839 
// v->ops->setpreallocationcoo
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  // Preallocate for ncoo COO-style insertions at indices coo_i: build the sequential
  // (host) COO mappings first, then set up the CUPM (device) side on top of them.
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1852 
1853 namespace kernels
1854 {
template <typename F>
PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
{
  // Scatter-accumulate COO values into the vector array. For each destination entry i
  // (iterated via grid_stride_1D), sum the staged contributions vv[perm[k]] for
  // k in [jmap[i], jmap[i+1]) and either overwrite (INSERT_VALUES) or accumulate into
  // xv at position xvindex(i). xvindex lets callers remap destination indices
  // (identity for the seq kernel below).
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
    const auto  end = jmap[i + 1];
    const auto  idx = xvindex(i);
    PetscScalar sum = 0.0;

    for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

    if (imode == INSERT_VALUES) {
      xv[idx] = sum;
    } else {
      xv[idx] += sum;
    }
  });
  return;
}
1873 
1874 namespace
1875 {
1876 
// kernel entry point: identity destination mapping (local index i writes to xv[i])
PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
  return;
}
1882 
1883 } // namespace
1884 
1885 #if PetscDefined(USING_HCC)
1886 namespace do_not_use
1887 {
1888 
1889 // Needed to silence clang warning:
1890 //
1891 // warning: function 'FUNCTION NAME' is not needed and will not be emitted
1892 //
1893 // The warning is silly, since the function *is* used, however the host compiler does not
1894 // appear see this. Likely because the function using it is in a template.
1895 //
1896 // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
// ODR-use sum_kernel so the host compilation pass does not warn it is unused (see above)
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}
1901 
// ODR-use add_coo_values so the host compilation pass does not warn it is unused
inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_coo_values;
}
1906 
1907 } // namespace do_not_use
1908 #endif
1909 
1910 } // namespace kernels
1911 
1912 // v->ops->setvaluescoo
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  // Assemble the values v[] (one per preallocated COO entry) into x using the jmap1/perm1
  // mappings built by SetPreallocationCOO(). v may live in host or device memory; host
  // input is staged through a temporary device buffer.
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry, so a write-only access suffices; ADD_VALUES
    // must read the existing entries too
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // release the staging buffer (if any) and wait for the async work to finish before the
  // caller may reuse/free v[]
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1944 
1945 } // namespace impl
1946 
1947 // ==========================================================================================
1948 // VecSeq_CUPM - Implementations
1949 // ==========================================================================================
1950 
1951 namespace
1952 {
1953 
1954 template <device::cupm::DeviceType T>
1955 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
1956 {
1957   PetscFunctionBegin;
1958   PetscValidPointer(v, 4);
1959   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
1960   PetscFunctionReturn(PETSC_SUCCESS);
1961 }
1962 
// Create a sequential CUPM vector of local length n and block size bs that uses the
// caller-provided host and/or device arrays instead of allocating its own storage.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  // cpuarray may legitimately be NULL (device-only usage), but if n > 0 and it is given
  // it must be a valid scalar pointer
  if (n && cpuarray) PetscValidScalarPointer(cpuarray, 4);
  PetscValidPointer(v, 6);
  PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1972 
// common implementation behind the typed get-array wrappers below; mode selects
// read / write / read-write access semantics
template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
  PetscValidPointer(a, 2);
  // resolve a nullptr dctx to a usable context before handing it to the impl
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1983 
// common implementation behind the typed restore-array wrappers below; must be paired
// with the matching VecCUPMGetArrayAsync_Private() access mode
template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
  // resolve a nullptr dctx to a usable context before handing it to the impl
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1993 
// read-write access to the device array
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2001 
// restore an array obtained with VecCUPMGetArrayAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2009 
// read-only access to the device array; const_cast is safe because the READ access mode
// guarantees the impl will not modify the data
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2017 
// restore an array obtained with VecCUPMGetArrayReadAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2025 
// write-only access to the device array (existing values need not be valid on return)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2033 
// restore an array obtained with VecCUPMGetArrayWriteAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2041 
// temporarily substitute the vector's device array with a (undone by ResetArray)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2050 
// permanently replace the vector's device array with a
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2059 
// restore the device array that was active before VecCUPMPlaceArrayAsync()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
  PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2068 
2069 } // anonymous namespace
2070 
2071 } // namespace cupm
2072 
2073 } // namespace vec
2074 
2075 } // namespace Petsc
2076 
2077 #endif // PETSCVECSEQCUPM_HPP
2078