xref: /petsc/include/petsc/private/matdensecupmimpl.h (revision d9acb416d05abeed0a33bde3a81aeb2ea0364f6a)
1 #pragma once
2 
3 #define PETSC_SKIP_IMMINTRIN_H_CUDAWORKAROUND 1
4 #include <petsc/private/matimpl.h> /*I <petscmat.h> I*/
5 
6 #ifdef __cplusplus
7   #include <petsc/private/deviceimpl.h>
8   #include <petsc/private/cupmsolverinterface.hpp>
9   #include <petsc/private/cupmobject.hpp>
10 
11   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
12   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
13 
14   #include <thrust/device_vector.h>
15   #include <thrust/device_ptr.h>
16   #include <thrust/iterator/counting_iterator.h>
17   #include <thrust/iterator/transform_iterator.h>
18   #include <thrust/iterator/permutation_iterator.h>
19   #include <thrust/transform.h>
20   #include <thrust/copy.h>
21 
22 namespace Petsc
23 {
24 
namespace vec
{

namespace cupm
{

namespace impl
{

// Forward declarations of the CUPM vector backends so that MatDense_CUPM can reference
// them (see MATDENSECUPM_BASE_HEADER below) without including the full vector headers
template <device::cupm::DeviceType>
class VecSeq_CUPM;
template <device::cupm::DeviceType>
class VecMPI_CUPM;

} // namespace impl

} // namespace cupm

} // namespace vec
44 
45 namespace mat
46 {
47 
48 namespace cupm
49 {
50 
51 namespace impl
52 {
53 
54 // ==========================================================================================
55 // MatDense_CUPM_Base
56 //
57 // A base class to separate out the CRTP code from the common CUPM stuff (like the composed
58 // function names).
59 // ==========================================================================================
60 
template <device::cupm::DeviceType T>
class MatDense_CUPM_Base : protected device::cupm::impl::CUPMObject<T> {
public:
  PETSC_CUPMOBJECT_HEADER(T);

  // Declares a static member function returning the stringized name of the composed
  // operation for the backend selected by T, e.g. MatDenseCUPMGetArray_C() yields
  // "MatDenseCUDAGetArray_C" for CUDA and "MatDenseHIPGetArray_C" for HIP
  #define MatDenseCUPMComposedOpDecl(OP_NAME) \
    PETSC_NODISCARD static constexpr const char *PetscConcat(MatDenseCUPM, OP_NAME)() noexcept \
    { \
      return T == device::cupm::DeviceType::CUDA ? PetscStringize(PetscConcat(MatDenseCUDA, OP_NAME)) : PetscStringize(PetscConcat(MatDenseHIP, OP_NAME)); \
    }

  // clang-format off
  MatDenseCUPMComposedOpDecl(GetArray_C)
  MatDenseCUPMComposedOpDecl(GetArrayRead_C)
  MatDenseCUPMComposedOpDecl(GetArrayWrite_C)
  MatDenseCUPMComposedOpDecl(RestoreArray_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayRead_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayWrite_C)
  MatDenseCUPMComposedOpDecl(PlaceArray_C)
  MatDenseCUPMComposedOpDecl(ReplaceArray_C)
  MatDenseCUPMComposedOpDecl(ResetArray_C)
  MatDenseCUPMComposedOpDecl(SetPreallocation_C)
  // clang-format on

  #undef MatDenseCUPMComposedOpDecl

  // Backend-dependent type names, resolved at compile time from T (CUDA or HIP)
  PETSC_NODISCARD static constexpr MatType       MATSEQDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATMPIDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatSolverType MATSOLVERCUPM() noexcept;
};
92 
93 // ==========================================================================================
94 // MatDense_CUPM_Base -- Public API
95 // ==========================================================================================
96 
97 template <device::cupm::DeviceType T>
98 inline constexpr MatType MatDense_CUPM_Base<T>::MATSEQDENSECUPM() noexcept
99 {
100   return T == device::cupm::DeviceType::CUDA ? MATSEQDENSECUDA : MATSEQDENSEHIP;
101 }
102 
103 template <device::cupm::DeviceType T>
104 inline constexpr MatType MatDense_CUPM_Base<T>::MATMPIDENSECUPM() noexcept
105 {
106   return T == device::cupm::DeviceType::CUDA ? MATMPIDENSECUDA : MATMPIDENSEHIP;
107 }
108 
109 template <device::cupm::DeviceType T>
110 inline constexpr MatType MatDense_CUPM_Base<T>::MATDENSECUPM() noexcept
111 {
112   return T == device::cupm::DeviceType::CUDA ? MATDENSECUDA : MATDENSEHIP;
113 }
114 
115 template <device::cupm::DeviceType T>
116 inline constexpr MatSolverType MatDense_CUPM_Base<T>::MATSOLVERCUPM() noexcept
117 {
118   return T == device::cupm::DeviceType::CUDA ? MATSOLVERCUDA : MATSOLVERHIP;
119 }
120 
  // Boilerplate for classes deriving from MatDense_CUPM_Base<T>: re-exports the CUPM
  // object header, aliases the matching vector backends, and pulls the base class'es
  // protected type/name accessors into the derived class'es scope
  #define MATDENSECUPM_BASE_HEADER(T) \
    PETSC_CUPMOBJECT_HEADER(T); \
    using VecSeq_CUPM = ::Petsc::vec::cupm::impl::VecSeq_CUPM<T>; \
    using VecMPI_CUPM = ::Petsc::vec::cupm::impl::VecMPI_CUPM<T>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSEQDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATMPIDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSOLVERCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMSetPreallocation_C
139 
// forward declare the concrete sequential and parallel (MPI) implementations; both derive
// from MatDense_CUPM below (it is their CRTP base)
template <device::cupm::DeviceType>
class MatDense_Seq_CUPM;
template <device::cupm::DeviceType>
class MatDense_MPI_CUPM;
145 
146 // ==========================================================================================
147 // MatDense_CUPM
148 //
149 // The true "base" class for MatDenseCUPM. The reason MatDense_CUPM and MatDense_CUPM_Base
150 // exist is to separate out the CRTP code from the non-crtp code so that the generic functions
151 // can be called via templates below.
152 // ==========================================================================================
153 
template <device::cupm::DeviceType T, typename Derived>
class MatDense_CUPM : protected MatDense_CUPM_Base<T> {
private:
  // Debug-only sanity check that a sequential matrix spans all rows/columns (see definition)
  static PetscErrorCode CheckSaneSequentialMatSizes_(Mat) noexcept;

protected:
  MATDENSECUPM_BASE_HEADER(T);

  // RAII handle over the matrix array for a given memory type and access mode, defined below
  template <PetscMemType, PetscMemoryAccessMode>
  class MatrixArray;

  // Cast the Mat to its host struct, i.e. return the result of (Mat_SeqDense *)m->data
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr auto    MatIMPLCast(Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(U::MatIMPLCast_(m))
  // The concrete MatType of the derived class (delegates to Derived::MATIMPLCUPM_())
  PETSC_NODISCARD static constexpr MatType MATIMPLCUPM() noexcept;

  // Shared creation core for MatCreateSeqDenseCUPM()/MatCreateMPIDenseCUPM()
  static PetscErrorCode CreateIMPLDenseCUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, PetscInt, PetscScalar *, Mat *, PetscDeviceContext, bool) noexcept;
  static PetscErrorCode SetPreallocation(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  // Apply a unary functor to each locally owned diagonal entry on the device
  template <typename F>
  static PetscErrorCode DiagonalUnaryTransform(Mat, PetscDeviceContext, F &&) noexcept;

  static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
  static PetscErrorCode GetDiagonal(Mat, Vec) noexcept;

  // Scoped accessors returning a MatrixArray for each memory-type/access-mode combination
  PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
};
186 
187 // ==========================================================================================
188 // MatDense_CUPM::MatrixArray
189 // ==========================================================================================
190 
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
class MatDense_CUPM<T, D>::MatrixArray : public device::cupm::impl::RestoreableArray<T, MT, MA> {
  using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>;

public:
  // Acquires the array of the Mat on construction and restores it on destruction (RAII)
  MatrixArray(PetscDeviceContext, Mat) noexcept;
  ~MatrixArray() noexcept;

  // must declare move constructor since we declare a destructor
  constexpr MatrixArray(MatrixArray &&) noexcept;

private:
  Mat m_ = nullptr; // the matrix whose array is held; reset to nullptr when moved-from
};
206 
207 // ==========================================================================================
208 // MatDense_CUPM::MatrixArray -- Public API
209 // ==========================================================================================
210 
// Acquire the array of m for memory type MT and access mode MA, storing it in this->ptr_
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(PetscDeviceContext dctx, Mat m) noexcept : base_type{dctx}, m_{m}
{
  PetscFunctionBegin;
  // constructors cannot return a PetscErrorCode, so errors must abort
  PetscCallAbort(PETSC_COMM_SELF, D::template GetArray<MT, MA>(m, &this->ptr_, dctx));
  PetscFunctionReturnVoid();
}
219 
// Restore the array acquired in the constructor
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::~MatrixArray() noexcept
{
  PetscFunctionBegin;
  // NOTE(review): m_ is nullptr for a moved-from object, so RestoreArray is presumably a
  // no-op (or tolerant of a null Mat) in that case -- confirm in D::RestoreArray
  PetscCallAbort(PETSC_COMM_SELF, D::template RestoreArray<MT, MA>(m_, &this->ptr_, this->dctx_));
  PetscFunctionReturnVoid();
}
228 
// Move construction: steal the pointer/matrix from other, leaving it restorable-but-empty
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline constexpr MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(MatrixArray &&other) noexcept : base_type{std::move(other)}, m_{util::exchange(other.m_, nullptr)}
{
}
234 
235 // ==========================================================================================
236 // MatDense_CUPM -- Private API
237 // ==========================================================================================
238 
239 template <device::cupm::DeviceType T, typename D>
240 inline PetscErrorCode MatDense_CUPM<T, D>::CheckSaneSequentialMatSizes_(Mat A) noexcept
241 {
242   PetscFunctionBegin;
243   if (PetscDefined(USE_DEBUG)) {
244     PetscBool isseq;
245 
246     PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), D::MATSEQDENSECUPM(), &isseq));
247     if (isseq) {
248       // doing this check allows both sequential and parallel implementations to just pass in
249       // A, otherwise they would need to specify rstart, rend, and cols separately!
250       PetscCheck(A->rmap->rstart == 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Sequential matrix row start %" PetscInt_FMT " != 0?", A->rmap->rstart);
251       PetscCheck(A->rmap->rend == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Sequential matrix row end %" PetscInt_FMT " != number of rows %" PetscInt_FMT, A->rmap->rend, A->rmap->n);
252       PetscCheck(A->cmap->n == A->cmap->N, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Sequential matrix number of local columns %" PetscInt_FMT " != number of global columns %" PetscInt_FMT, A->cmap->n, A->cmap->n);
253     }
254   }
255   PetscFunctionReturn(PETSC_SUCCESS);
256 }
257 
258 // ==========================================================================================
259 // MatDense_CUPM -- Protected API
260 // ==========================================================================================
261 
// The concrete MatType of the derived (CRTP) class; simply forwards to the derived
// class'es MATIMPLCUPM_() hook
template <device::cupm::DeviceType T, typename D>
inline constexpr MatType MatDense_CUPM<T, D>::MATIMPLCUPM() noexcept
{
  return D::MATIMPLCUPM_();
}
267 
268 // Common core for MatCreateSeqDenseCUPM() and MatCreateMPIDenseCUPM()
269 template <device::cupm::DeviceType T, typename D>
270 inline PetscErrorCode MatDense_CUPM<T, D>::CreateIMPLDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx, bool preallocate) noexcept
271 {
272   Mat mat;
273 
274   PetscFunctionBegin;
275   PetscAssertPointer(A, 7);
276   PetscCall(MatCreate(comm, &mat));
277   PetscCall(MatSetSizes(mat, m, n, M, N));
278   PetscCall(MatSetType(mat, D::MATIMPLCUPM()));
279   if (preallocate) {
280     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
281     PetscCall(D::SetPreallocation(mat, dctx, data));
282   }
283   *A = mat;
284   PetscFunctionReturn(PETSC_SUCCESS);
285 }
286 
// Preallocate the matrix storage, optionally adopting the caller-supplied device array
// (device_array may be NULL, in which case the implementation allocates). Marks the matrix
// preallocated and assembled on success.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::SetPreallocation(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  // cannot use PetscValidHeaderSpecificType(..., MATIMPLCUPM()) since the incoming matrix
  // might be the local (sequential) matrix of a MatMPIDense_CUPM. Since this would be called
  // from the MPI matrix'es impl MATIMPLCUPM() would return MATMPIDENSECUPM().
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscCheckTypeNames(A, D::MATSEQDENSECUPM(), D::MATMPIDENSECUPM());
  PetscCall(PetscLayoutSetUp(A->rmap));
  PetscCall(PetscLayoutSetUp(A->cmap));
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(D::SetPreallocation_(A, dctx, device_array));
  A->preallocated = PETSC_TRUE;
  A->assembled    = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
304 
305 namespace detail
306 {
307 
308 // ==========================================================================================
309 // MatrixIteratorBase
310 //
311 // A base class for creating thrust iterators over the local sub-matrix. This will set up the
312 // proper iterator definitions so thrust knows how to handle things properly. Template
313 // parameters are as follows:
314 //
315 // - Iterator:
316 // The type of the primary array iterator. Usually this is
317 // thrust::device_pointer<PetscScalar>::iterator.
318 //
319 // - IndexFunctor:
320 // This should be a functor which contains an operator() that when called with an index `i`,
321 // returns the i'th permuted index into the array. For example, it could return the i'th
322 // diagonal entry.
323 // ==========================================================================================
template <typename Iterator, typename IndexFunctor>
class MatrixIteratorBase {
public:
  using array_iterator_type = Iterator;
  using index_functor_type  = IndexFunctor;

  using difference_type     = typename thrust::iterator_difference<array_iterator_type>::type;
  using CountingIterator    = thrust::counting_iterator<difference_type>;
  using TransformIterator   = thrust::transform_iterator<index_functor_type, CountingIterator>;
  using PermutationIterator = thrust::permutation_iterator<array_iterator_type, TransformIterator>;
  using iterator            = PermutationIterator; // type of the begin/end iterator

  constexpr MatrixIteratorBase(array_iterator_type first, array_iterator_type last, index_functor_type idx_func) noexcept : first{std::move(first)}, last{std::move(last)}, func{std::move(idx_func)} { }

  // Composes counting -> transform -> permutation: element i of the returned iterator is
  // *(first + func(i)), i.e. the func-permuted view of the underlying array
  PETSC_NODISCARD iterator begin() const noexcept
  {
    return PermutationIterator{
      first, TransformIterator{CountingIterator{0}, func}
    };
  }

protected:
  array_iterator_type first; // start of the underlying (unpermuted) array range
  array_iterator_type last;  // end of the underlying range; used by derived classes' end()
  index_functor_type  func;  // maps a linear index to the permuted array index
};
350 
351 // ==========================================================================================
352 // StridedIndexFunctor
353 //
354 // Iterator which permutes a linear index range into strided matrix indices. Usually used to
355 // get the diagonal.
356 // ==========================================================================================
template <typename T>
struct StridedIndexFunctor {
  // map linear index i to the strided array offset i * stride
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &i) const noexcept { return stride * i; }

  T stride;
};
363 
// Iterates the diagonal entries of a (column-major) dense array as every stride'th element
// of [first, last); see MakeDiagonalIterator() for how first/last/stride are derived
template <typename Iterator>
class DiagonalIterator : public MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>> {
public:
  using base_type = MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>>;

  using difference_type = typename base_type::difference_type;
  using iterator        = typename base_type::iterator;

  constexpr DiagonalIterator(Iterator first, Iterator last, difference_type stride) noexcept : base_type{std::move(first), std::move(last), {stride}} { }

  // number of diagonal entries is ceil((last - first) / stride), computed with the usual
  // (x + stride - 1) / stride integer trick
  PETSC_NODISCARD iterator end() const noexcept { return this->begin() + (this->last - this->first + this->func.stride - 1) / this->func.stride; }
};
376 
// Build a DiagonalIterator over the locally owned diagonal entries of a column-major dense
// array with leading dimension lda, where this rank owns global rows [rstart, rend) of a
// matrix with `cols` global columns
template <typename T>
inline DiagonalIterator<typename thrust::device_vector<T>::iterator> MakeDiagonalIterator(T *data, PetscInt rstart, PetscInt rend, PetscInt cols, PetscInt lda) noexcept
{
  // diagonal entry (i, i) only exists for i < cols, so clamp the owned row range
  const auto        rend2 = std::min(rend, cols);
  // column-major: global diagonal entry i lives at local offset (i - rstart) + i * lda;
  // begin is that offset for i = rstart, end is the one-past offset for i = rend2
  const std::size_t begin = rstart * lda;
  const std::size_t end   = rend2 - rstart + rend2 * lda;
  const auto        dptr  = thrust::device_pointer_cast(data);

  // NOTE(review): if rstart > rend2 (a rank whose rows all lie past the last column) the
  // range has last < first; callers must guard this case as DiagonalUnaryTransform() does
  return {dptr + begin, dptr + end, lda + 1};
}
387 
388 } // namespace detail
389 
// Apply `functor` in place to every locally owned diagonal entry of A on the device.
// Used to implement e.g. MatShift(). No-op (no array access, no flop logging) when this
// rank owns no diagonal entries.
template <device::cupm::DeviceType T, typename D>
template <typename F>
inline PetscErrorCode MatDense_CUPM<T, D>::DiagonalUnaryTransform(Mat A, PetscDeviceContext dctx, F &&functor) noexcept
{
  const auto rstart = A->rmap->rstart;
  const auto rend   = A->rmap->rend;
  const auto gcols  = A->cmap->N;
  // last owned row with a diagonal entry; diagonal only exists for row < column count
  const auto rend2  = std::min(rend, gcols);

  PetscFunctionBegin;
  PetscCall(CheckSaneSequentialMatSizes_(A));
  if (rend2 > rstart) {
    const auto   da = D::DeviceArrayReadWrite(dctx, A);
    cupmStream_t stream;
    PetscInt     lda;

    PetscCall(MatDenseGetLDA(A, &lda));
    PetscCall(D::GetHandlesFrom_(dctx, &stream));
    {
      auto diagonal = detail::MakeDiagonalIterator(da.data(), rstart, rend, gcols, lda);

      // transform in place: output iterator == input iterator
      // clang-format off
      PetscCallThrust(
        THRUST_CALL(
          thrust::transform,
          stream,
          diagonal.begin(), diagonal.end(), diagonal.begin(),
          std::forward<F>(functor)
        )
      );
      // clang-format on
    }
    // one flop per transformed diagonal entry
    PetscCall(PetscLogGpuFlops(rend2 - rstart));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
426 
// Implements MatShift() for dense CUPM matrices: adds alpha to each locally owned
// diagonal entry on the device (via the make_plus_equals functor)
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::Shift(Mat A, PetscScalar alpha) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(DiagonalUnaryTransform(A, dctx, device::cupm::functors::make_plus_equals(alpha)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
437 
438 template <device::cupm::DeviceType T, typename D>
439 inline PetscErrorCode MatDense_CUPM<T, D>::GetDiagonal(Mat A, Vec v) noexcept
440 {
441   const auto         rstart = A->rmap->rstart;
442   const auto         rend   = A->rmap->rend;
443   const auto         gcols  = A->cmap->N;
444   PetscInt           lda;
445   PetscDeviceContext dctx;
446 
447   PetscFunctionBegin;
448   PetscCall(CheckSaneSequentialMatSizes_(A));
449   PetscCall(GetHandles_(&dctx));
450   PetscCall(MatDenseGetLDA(A, &lda));
451   {
452     auto dv       = VecSeq_CUPM::DeviceArrayWrite(dctx, v);
453     auto da       = D::DeviceArrayRead(dctx, A);
454     auto diagonal = detail::MakeDiagonalIterator(da.data(), rstart, rend, gcols, lda);
455     // We must have this cast outside of THRUST_CALL(). Without it, GCC 6.4 - 7.5, and 11.3.0
456     // throw spurious warnings:
457     //
458     // warning: 'MatDense_CUPM<...>::GetDiagonal(Mat, Vec)::<lambda()>' declared with greater
459     // visibility than the type of its field 'MatDense_CUPM<...>::GetDiagonal(Mat,
460     // Vec)::<lambda()>::<dv capture>' [-Wattributes]
461     // 460 |     PetscCallThrust(
462     //     |     ^~~~~~~~~~~~~~~~
463     auto         dvp = thrust::device_pointer_cast(dv.data());
464     cupmStream_t stream;
465 
466     PetscCall(GetHandlesFrom_(dctx, &stream));
467     PetscCallThrust(THRUST_CALL(thrust::copy, stream, diagonal.begin(), diagonal.end(), dvp));
468   }
469   PetscFunctionReturn(PETSC_SUCCESS);
470 }
471 
  // Compose the function named `op_str` on `pobj`, selecting the host implementation
  // `op_host` when `use_host` is true and the device implementation (trailing arguments)
  // otherwise
  #define MatComposeOp_CUPM(use_host, pobj, op_str, op_host, ...) \
    do { \
      if (use_host) { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, op_host)); \
      } else { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, __VA_ARGS__)); \
      } \
    } while (0)
480 
  // Set the ops-table entry `op_name` on `mat`, selecting the host implementation
  // `op_host` when `use_host` is true and the device implementation (trailing arguments)
  // otherwise
  #define MatSetOp_CUPM(use_host, mat, op_name, op_host, ...) \
    do { \
      if (use_host) { \
        (mat)->ops->op_name = op_host; \
      } else { \
        (mat)->ops->op_name = __VA_ARGS__; \
      } \
    } while (0)
489 
  // Boilerplate for the concrete implementations deriving from MatDense_CUPM<T, Derived>:
  // extends MATDENSECUPM_BASE_HEADER, befriends the CRTP base so it can reach the derived
  // class'es private hooks, and pulls the base's protected helpers into scope
  #define MATDENSECUPM_HEADER(T, ...) \
    MATDENSECUPM_BASE_HEADER(T); \
    friend class ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MatIMPLCast; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MATIMPLCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::CreateIMPLDenseCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::SetPreallocation; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DiagonalUnaryTransform; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::Shift; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::GetDiagonal
506 
507 } // namespace impl
508 
509 // ==========================================================================================
510 // MatDense_CUPM -- Implementations
511 // ==========================================================================================
512 
513 template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
514 inline PetscErrorCode MatDenseCUPMGetArray_Private(Mat A, PetscScalar **array) noexcept
515 {
516   PetscFunctionBegin;
517   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
518   PetscAssertPointer(array, 2);
519   switch (access) {
520   case PETSC_MEMORY_ACCESS_READ:
521     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C(), (Mat, PetscScalar **), (A, array));
522     break;
523   case PETSC_MEMORY_ACCESS_WRITE:
524     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C(), (Mat, PetscScalar **), (A, array));
525     break;
526   case PETSC_MEMORY_ACCESS_READ_WRITE:
527     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C(), (Mat, PetscScalar **), (A, array));
528     break;
529   }
530   if (PetscMemoryAccessWrite(access)) PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
531   PetscFunctionReturn(PETSC_SUCCESS);
532 }
533 
534 template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
535 inline PetscErrorCode MatDenseCUPMRestoreArray_Private(Mat A, PetscScalar **array) noexcept
536 {
537   PetscFunctionBegin;
538   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
539   if (array) PetscAssertPointer(array, 2);
540   switch (access) {
541   case PETSC_MEMORY_ACCESS_READ:
542     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C(), (Mat, PetscScalar **), (A, array));
543     break;
544   case PETSC_MEMORY_ACCESS_WRITE:
545     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C(), (Mat, PetscScalar **), (A, array));
546     break;
547   case PETSC_MEMORY_ACCESS_READ_WRITE:
548     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C(), (Mat, PetscScalar **), (A, array));
549     break;
550   }
551   if (PetscMemoryAccessWrite(access)) {
552     PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
553     A->offloadmask = PETSC_OFFLOAD_GPU;
554   }
555   if (array) *array = nullptr;
556   PetscFunctionReturn(PETSC_SUCCESS);
557 }
558 
// Get the device array of A with read-write access
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
566 
// Get the device array of A with read-only access. The const_cast only adapts the pointer
// type for the shared private core; the READ path does not write through it
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
574 
// Get the device array of A with write-only access (previous contents need not be valid)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
582 
// Restore the device array obtained with MatDenseCUPMGetArray(); *array is set to NULL
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
590 
// Restore the device array obtained with MatDenseCUPMGetArrayRead(); *array is set to NULL
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
598 
// Restore the device array obtained with MatDenseCUPMGetArrayWrite(); *array is set to NULL
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
606 
// Temporarily replace the matrix'es device array with the user-provided one (undone by
// MatDenseCUPMResetArray()). The state increases and the offload mask is set to GPU since
// the supplied device array becomes the authoritative data
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMPlaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
617 
// Permanently replace the matrix'es device array with the user-provided one (unlike
// PlaceArray there is no corresponding reset). State and offload mask updated as above
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
628 
// Restore the matrix'es original device array after a MatDenseCUPMPlaceArray()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMResetArray(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C(), (Mat), (A));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
638 
// Preallocate the matrix'es device storage, optionally adopting device_data (may be NULL).
// dctx may be NULL (the default), in which case the implementation falls back to the
// current default device context
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMSetPreallocation(Mat A, PetscScalar *device_data, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMSetPreallocation_C(), (Mat, PetscDeviceContext, PetscScalar *), (A, dctx, device_data));
  PetscFunctionReturn(PETSC_SUCCESS);
}
647 
648 } // namespace cupm
649 
650 } // namespace mat
651 
652 } // namespace Petsc
653 
654 #endif // __cplusplus
655