xref: /petsc/include/petsc/private/matdensecupmimpl.h (revision e4094ef18e7e53fda86cf35f3a47fda48a8e77d8)
1 #ifndef PETSCMATDENSECUPMIMPL_H
2 #define PETSCMATDENSECUPMIMPL_H
3 
4 #define PETSC_SKIP_IMMINTRIN_H_CUDAWORKAROUND 1
5 #include <petsc/private/matimpl.h> /*I <petscmat.h> I*/
6 
7 #ifdef __cplusplus
8   #include <petsc/private/deviceimpl.h>
9   #include <petsc/private/cupmsolverinterface.hpp>
10   #include <petsc/private/cupmobject.hpp>
11 
12   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
13   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
14 
15   #include <thrust/device_vector.h>
16   #include <thrust/device_ptr.h>
17   #include <thrust/iterator/counting_iterator.h>
18   #include <thrust/iterator/transform_iterator.h>
19   #include <thrust/iterator/permutation_iterator.h>
20   #include <thrust/transform.h>
21   #include <thrust/copy.h>
22 
23 namespace Petsc
24 {
25 
26 namespace vec
27 {
28 
29 namespace cupm
30 {
31 
32 namespace impl
33 {
34 
35 template <device::cupm::DeviceType>
36 class VecSeq_CUPM;
37 template <device::cupm::DeviceType>
38 class VecMPI_CUPM;
39 
40 } // namespace impl
41 
42 } // namespace cupm
43 
44 } // namespace vec
45 
46 namespace mat
47 {
48 
49 namespace cupm
50 {
51 
52 namespace impl
53 {
54 
55 // ==========================================================================================
56 // MatDense_CUPM_Base
57 //
58 // A base class to separate out the CRTP code from the common CUPM stuff (like the composed
59 // function names).
60 // ==========================================================================================
61 
// Maps the compile-time CUPM device type T to the backend-specific composed-function name
// strings and matrix/solver type names (MatDenseCUDA* vs MatDenseHIP*). Inherits the
// generic CUPM object helpers from CUPMObject<T>.
template <device::cupm::DeviceType T>
class MatDense_CUPM_Base : protected device::cupm::impl::CUPMObject<T> {
public:
  PETSC_CUPMOBJECT_HEADER(T);

  // Declares a static constexpr member -- e.g. MatDenseCUPMGetArray_C() -- returning the
  // composed-operation name for the current backend as a string literal:
  // "MatDenseCUDA<op>" when T is CUDA, "MatDenseHIP<op>" otherwise. (No '//' comments may
  // appear inside the macro body itself because of the trailing line continuations.)
  #define MatDenseCUPMComposedOpDecl(OP_NAME) \
    PETSC_NODISCARD static constexpr const char *PetscConcat(MatDenseCUPM, OP_NAME)() noexcept \
    { \
      return T == device::cupm::DeviceType::CUDA ? PetscStringize(PetscConcat(MatDenseCUDA, OP_NAME)) : PetscStringize(PetscConcat(MatDenseHIP, OP_NAME)); \
    }

  // clang-format off
  MatDenseCUPMComposedOpDecl(GetArray_C)
  MatDenseCUPMComposedOpDecl(GetArrayRead_C)
  MatDenseCUPMComposedOpDecl(GetArrayWrite_C)
  MatDenseCUPMComposedOpDecl(RestoreArray_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayRead_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayWrite_C)
  MatDenseCUPMComposedOpDecl(PlaceArray_C)
  MatDenseCUPMComposedOpDecl(ReplaceArray_C)
  MatDenseCUPMComposedOpDecl(ResetArray_C)
  MatDenseCUPMComposedOpDecl(SetPreallocation_C)
    // clang-format on

  #undef MatDenseCUPMComposedOpDecl

  // Backend-specific matrix and solver type names; defined out-of-line immediately below
    PETSC_NODISCARD static constexpr MatType MATSEQDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATMPIDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatSolverType MATSOLVERCUPM() noexcept;
};
93 
94 // ==========================================================================================
95 // MatDense_CUPM_Base -- Public API
96 // ==========================================================================================
97 
98 template <device::cupm::DeviceType T>
99 inline constexpr MatType MatDense_CUPM_Base<T>::MATSEQDENSECUPM() noexcept
100 {
101   return T == device::cupm::DeviceType::CUDA ? MATSEQDENSECUDA : MATSEQDENSEHIP;
102 }
103 
104 template <device::cupm::DeviceType T>
105 inline constexpr MatType MatDense_CUPM_Base<T>::MATMPIDENSECUPM() noexcept
106 {
107   return T == device::cupm::DeviceType::CUDA ? MATMPIDENSECUDA : MATMPIDENSEHIP;
108 }
109 
110 template <device::cupm::DeviceType T>
111 inline constexpr MatType MatDense_CUPM_Base<T>::MATDENSECUPM() noexcept
112 {
113   return T == device::cupm::DeviceType::CUDA ? MATDENSECUDA : MATDENSEHIP;
114 }
115 
116 template <device::cupm::DeviceType T>
117 inline constexpr MatSolverType MatDense_CUPM_Base<T>::MATSOLVERCUPM() noexcept
118 {
119   return T == device::cupm::DeviceType::CUDA ? MATSOLVERCUDA : MATSOLVERHIP;
120 }
121 
  // Injects, into a deriving class, the CUPM object helpers, aliases for the CUPM vector
  // implementation classes, and using-declarations for every member inherited from
  // MatDense_CUPM_Base<T>. The using-declarations are required because members of a
  // dependent base class are not found by unqualified name lookup.
  #define MATDENSECUPM_BASE_HEADER(T) \
    PETSC_CUPMOBJECT_HEADER(T); \
    using VecSeq_CUPM = ::Petsc::vec::cupm::impl::VecSeq_CUPM<T>; \
    using VecMPI_CUPM = ::Petsc::vec::cupm::impl::VecMPI_CUPM<T>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSEQDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATMPIDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSOLVERCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMSetPreallocation_C
140 
141 // forward declare
142 template <device::cupm::DeviceType>
143 class MatDense_Seq_CUPM;
144 template <device::cupm::DeviceType>
145 class MatDense_MPI_CUPM;
146 
147 // ==========================================================================================
148 // MatDense_CUPM
149 //
150 // The true "base" class for MatDenseCUPM. The reason MatDense_CUPM and MatDense_CUPM_Base
151 // exist is to separate out the CRTP code from the non-crtp code so that the generic functions
152 // can be called via templates below.
153 // ==========================================================================================
154 
// CRTP base: Derived is the concrete implementation (MatDense_Seq_CUPM or
// MatDense_MPI_CUPM) supplying the MatIMPLCast_() / MATIMPLCUPM_() hooks and the
// GetArray/RestoreArray implementations used by MatrixArray below.
template <device::cupm::DeviceType T, typename Derived>
class MatDense_CUPM : protected MatDense_CUPM_Base<T> {
private:
  // Debug-only layout sanity check for sequential matrices; see definition below
  static PetscErrorCode CheckSaneSequentialMatSizes_(Mat) noexcept;

protected:
  MATDENSECUPM_BASE_HEADER(T);

  // RAII handle over the matrix' array for a given memory type and access mode; defined
  // below
  template <PetscMemType, PetscMemoryAccessMode>
  class MatrixArray;

  // Cast the Mat to its host struct, i.e. return the result of (Mat_SeqDense *)m->data
  // NOTE: PETSC_DECLTYPE_AUTO_RETURNS supplies the return type, body, and terminating
  // semicolon, hence none appears here
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr auto    MatIMPLCast(Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(U::MatIMPLCast_(m))
  PETSC_NODISCARD static constexpr MatType MATIMPLCUPM() noexcept;

  static PetscErrorCode CreateIMPLDenseCUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, PetscInt, PetscScalar *, Mat *, PetscDeviceContext, bool) noexcept;
  static PetscErrorCode SetPreallocation(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  // Apply a unary functor in-place to the locally owned diagonal entries
  template <typename F>
  static PetscErrorCode DiagonalUnaryTransform(Mat, PetscDeviceContext, F &&) noexcept;

  static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
  static PetscErrorCode GetDiagonal(Mat, Vec) noexcept;

  // Scoped accessors returning a MatrixArray for each (memory type, access mode) pair; the
  // array is automatically restored when the returned object goes out of scope
  PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
};
187 
188 // ==========================================================================================
189 // MatDense_CUPM::MatrixArray
190 // ==========================================================================================
191 
// RAII handle over a matrix' raw array: the constructor acquires the pointer via
// D::GetArray<MT, MA>() and the destructor restores it via D::RestoreArray<MT, MA>().
// Movable but (by implication of the declared destructor/move ctor) not copyable.
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
class MatDense_CUPM<T, D>::MatrixArray : public device::cupm::impl::RestoreableArray<T, MT, MA> {
  using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>;

public:
  MatrixArray(PetscDeviceContext, Mat) noexcept;
  ~MatrixArray() noexcept;

  // must declare move constructor since we declare a destructor
  constexpr MatrixArray(MatrixArray &&) noexcept;

private:
  // the matrix whose array is held; needed again at destruction time to restore it
  Mat m_ = nullptr;
};
207 
208 // ==========================================================================================
209 // MatDense_CUPM::MatrixArray -- Public API
210 // ==========================================================================================
211 
// Acquire the array pointer (stored in this->ptr_ by the base class) from the derived
// implementation. Uses PetscCallAbort() since a constructor cannot propagate an error code.
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(PetscDeviceContext dctx, Mat m) noexcept : base_type{dctx}, m_{m}
{
  PetscFunctionBegin;
  PetscCallAbort(PETSC_COMM_SELF, D::template GetArray<MT, MA>(m, &this->ptr_, dctx));
  PetscFunctionReturnVoid();
}
220 
// Restore the array on scope exit. As with the constructor, errors must abort since a
// destructor cannot return an error code.
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::~MatrixArray() noexcept
{
  PetscFunctionBegin;
  PetscCallAbort(PETSC_COMM_SELF, D::template RestoreArray<MT, MA>(m_, &this->ptr_, this->dctx_));
  PetscFunctionReturnVoid();
}
229 
// Transfer ownership: the moved-from object is left with m_ == nullptr (and, via the base
// move, presumably a null ptr_) so its destructor does not restore the stolen array --
// NOTE(review): relies on D::RestoreArray tolerating a null Mat; confirm in the backend.
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline constexpr MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(MatrixArray &&other) noexcept : base_type{std::move(other)}, m_{util::exchange(other.m_, nullptr)}
{
}
235 
236 // ==========================================================================================
237 // MatDense_CUPM -- Private API
238 // ==========================================================================================
239 
240 template <device::cupm::DeviceType T, typename D>
241 inline PetscErrorCode MatDense_CUPM<T, D>::CheckSaneSequentialMatSizes_(Mat A) noexcept
242 {
243   PetscFunctionBegin;
244   if (PetscDefined(USE_DEBUG)) {
245     PetscBool isseq;
246 
247     PetscCall(PetscObjectTypeCompare(PetscObjectCast(A), D::MATSEQDENSECUPM(), &isseq));
248     if (isseq) {
249       // doing this check allows both sequential and parallel implementations to just pass in
250       // A, otherwise they would need to specify rstart, rend, and cols separately!
251       PetscCheck(A->rmap->rstart == 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Sequential matrix row start %" PetscInt_FMT " != 0?", A->rmap->rstart);
252       PetscCheck(A->rmap->rend == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Sequential matrix row end %" PetscInt_FMT " != number of rows %" PetscInt_FMT, A->rmap->rend, A->rmap->n);
253       PetscCheck(A->cmap->n == A->cmap->N, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Sequential matrix number of local columns %" PetscInt_FMT " != number of global columns %" PetscInt_FMT, A->cmap->n, A->cmap->n);
254     }
255   }
256   PetscFunctionReturn(PETSC_SUCCESS);
257 }
258 
259 // ==========================================================================================
260 // MatDense_CUPM -- Protected API
261 // ==========================================================================================
262 
// Type name of the concrete implementation; simply delegates to the derived class'
// MATIMPLCUPM_() hook (sequential vs MPI dense is decided there)
template <device::cupm::DeviceType T, typename D>
inline constexpr MatType MatDense_CUPM<T, D>::MATIMPLCUPM() noexcept
{
  return D::MATIMPLCUPM_();
}
268 
269 // Common core for MatCreateSeqDenseCUPM() and MatCreateMPIDenseCUPM()
270 template <device::cupm::DeviceType T, typename D>
271 inline PetscErrorCode MatDense_CUPM<T, D>::CreateIMPLDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx, bool preallocate) noexcept
272 {
273   Mat mat;
274 
275   PetscFunctionBegin;
276   PetscValidPointer(A, 7);
277   PetscCall(MatCreate(comm, &mat));
278   PetscCall(MatSetSizes(mat, m, n, M, N));
279   PetscCall(MatSetType(mat, D::MATIMPLCUPM()));
280   if (preallocate) {
281     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
282     PetscCall(D::SetPreallocation(mat, dctx, data));
283   }
284   *A = mat;
285   PetscFunctionReturn(PETSC_SUCCESS);
286 }
287 
// Set up the storage of a dense CUPM matrix: finalize the row/column layouts and delegate
// the actual array setup to the derived class' SetPreallocation_() hook (which uses
// device_array directly when non-null). Marks the matrix preallocated and assembled.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::SetPreallocation(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  // cannot use PetscValidHeaderSpecificType(..., MATIMPLCUPM()) since the incoming matrix
  // might be the local (sequential) matrix of a MatMPIDense_CUPM. Since this would be called
  // from the MPI matrix'es impl MATIMPLCUPM() would return MATMPIDENSECUPM().
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscCheckTypeNames(A, D::MATSEQDENSECUPM(), D::MATMPIDENSECUPM());
  PetscCall(PetscLayoutSetUp(A->rmap));
  PetscCall(PetscLayoutSetUp(A->cmap));
  // a null dctx is allowed; resolve it to the default context first
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(D::SetPreallocation_(A, dctx, device_array));
  A->preallocated = PETSC_TRUE;
  A->assembled    = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
305 
306 namespace detail
307 {
308 
309 // ==========================================================================================
310 // MatrixIteratorBase
311 //
312 // A base class for creating thrust iterators over the local sub-matrix. This will set up the
313 // proper iterator definitions so thrust knows how to handle things properly. Template
314 // parameters are as follows:
315 //
316 // - Iterator:
317 // The type of the primary array iterator. Usually this is
318 // thrust::device_pointer<PetscScalar>::iterator.
319 //
320 // - IndexFunctor:
321 // This should be a functor which contains an operator() that when called with an index `i`,
322 // returns the i'th permuted index into the array. For example, it could return the i'th
323 // diagonal entry.
324 // ==========================================================================================
template <typename Iterator, typename IndexFunctor>
class MatrixIteratorBase {
public:
  using array_iterator_type = Iterator;
  using index_functor_type  = IndexFunctor;

  using difference_type     = typename thrust::iterator_difference<array_iterator_type>::type;
  using CountingIterator    = thrust::counting_iterator<difference_type>;
  using TransformIterator   = thrust::transform_iterator<index_functor_type, CountingIterator>;
  using PermutationIterator = thrust::permutation_iterator<array_iterator_type, TransformIterator>;
  using iterator            = PermutationIterator; // type of the begin/end iterator

  constexpr MatrixIteratorBase(array_iterator_type first, array_iterator_type last, index_functor_type idx_func) noexcept : first{std::move(first)}, last{std::move(last)}, func{std::move(idx_func)} { }

  // Iterator over the permuted entries: counts 0, 1, 2, ..., maps each count through func
  // into an array index, then dereferences first + that index
  PETSC_NODISCARD iterator begin() const noexcept
  {
    return PermutationIterator{
      first, TransformIterator{CountingIterator{0}, func}
    };
  }

protected:
  array_iterator_type first; // first element of the underlying array range
  array_iterator_type last;  // one past the last element; used by derived classes to compute end()
  index_functor_type  func;  // maps a linear counter to the permuted array index
};
351 
352 // ==========================================================================================
353 // StridedIndexFunctor
354 //
355 // Iterator which permutes a linear index range into strided matrix indices. Usually used to
356 // get the diagonal.
357 // ==========================================================================================
358 template <typename T>
359 struct StridedIndexFunctor {
360   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &i) const noexcept { return stride * i; }
361 
362   T stride;
363 };
364 
// Iterator range over the diagonal entries of a dense array: composes MatrixIteratorBase
// with a StridedIndexFunctor whose stride is supplied at construction (lda + 1 for the
// diagonal, see MakeDiagonalIterator() below)
template <typename Iterator>
class DiagonalIterator : public MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>> {
public:
  using base_type = MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>>;

  using difference_type = typename base_type::difference_type;
  using iterator        = typename base_type::iterator;

  constexpr DiagonalIterator(Iterator first, Iterator last, difference_type stride) noexcept : base_type{std::move(first), std::move(last), {stride}} { }

  // Number of strided entries in [first, last) is ceil((last - first) / stride), written
  // with the usual integer ceiling-division idiom (add stride - 1, then divide)
  PETSC_NODISCARD iterator end() const noexcept { return this->begin() + (this->last - this->first + this->func.stride - 1) / this->func.stride; }
};
377 
// Build a DiagonalIterator over the locally owned diagonal entries of a column-major dense
// array with leading dimension lda. `data` points at the first local entry, so the global
// diagonal entry (i, i), for i in [rstart, rend2) with rend2 = min(rend, cols), sits at
// local offset (i - rstart) + i * lda; successive diagonal entries are therefore lda + 1
// apart:
//   begin: i = rstart -> rstart * lda
//   end:   i = rend2  -> (rend2 - rstart) + rend2 * lda   (one past the last entry)
template <typename T>
inline DiagonalIterator<typename thrust::device_vector<T>::iterator> MakeDiagonalIterator(T *data, PetscInt rstart, PetscInt rend, PetscInt cols, PetscInt lda) noexcept
{
  const auto        rend2 = std::min(rend, cols); // rows beyond the last column hold no diagonal entry
  const std::size_t begin = rstart * lda;
  const std::size_t end   = rend2 - rstart + rend2 * lda;
  const auto        dptr  = thrust::device_pointer_cast(data);

  return {dptr + begin, dptr + end, lda + 1};
}
388 
389 } // namespace detail
390 
// Apply `functor` in-place to every locally owned diagonal entry of A, i.e. entries (i, i)
// for i in [rstart, min(rend, N)). The transform runs via thrust on the stream attached to
// dctx; one flop per transformed entry is logged.
template <device::cupm::DeviceType T, typename D>
template <typename F>
inline PetscErrorCode MatDense_CUPM<T, D>::DiagonalUnaryTransform(Mat A, PetscDeviceContext dctx, F &&functor) noexcept
{
  const auto rstart = A->rmap->rstart;
  const auto rend   = A->rmap->rend;
  const auto gcols  = A->cmap->N;
  const auto rend2  = std::min(rend, gcols); // rows past the last global column are off-diagonal

  PetscFunctionBegin;
  PetscCall(CheckSaneSequentialMatSizes_(A));
  if (rend2 > rstart) { // skip ranks owning no diagonal entries
    // RAII read-write access; the array is restored when da leaves this scope
    const auto   da = D::DeviceArrayReadWrite(dctx, A);
    cupmStream_t stream;
    PetscInt     lda;

    PetscCall(MatDenseGetLDA(A, &lda));
    PetscCall(D::GetHandlesFrom_(dctx, &stream));
    {
      auto diagonal = detail::MakeDiagonalIterator(da.data(), rstart, rend, gcols, lda);

      // clang-format off
      PetscCallThrust(
        THRUST_CALL(
          thrust::transform,
          stream,
          diagonal.begin(), diagonal.end(), diagonal.begin(),
          std::forward<F>(functor)
        )
      );
      // clang-format on
    }
    // one flop per diagonal entry touched
    PetscCall(PetscLogGpuFlops(rend2 - rstart));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
427 
// A += alpha * I: add alpha to each locally owned diagonal entry of A by running a
// plus-equals functor through DiagonalUnaryTransform()
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::Shift(Mat A, PetscScalar alpha) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  // GetHandles_ is one of the CUPM object helpers; obtains the device context to run on
  PetscCall(GetHandles_(&dctx));
  PetscCall(DiagonalUnaryTransform(A, dctx, device::cupm::functors::make_plus_equals(alpha)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
438 
// Copy the locally owned diagonal entries of A -- (i, i) for i in [rstart, min(rend, N)) --
// into the vector v via an asynchronous thrust::copy on the context's stream.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::GetDiagonal(Mat A, Vec v) noexcept
{
  const auto         rstart = A->rmap->rstart;
  const auto         rend   = A->rmap->rend;
  const auto         gcols  = A->cmap->N;
  PetscInt           lda;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(CheckSaneSequentialMatSizes_(A));
  PetscCall(GetHandles_(&dctx));
  PetscCall(MatDenseGetLDA(A, &lda));
  {
    // RAII accessors; both arrays are restored at the end of this scope
    auto dv       = VecSeq_CUPM::DeviceArrayWrite(dctx, v);
    auto da       = D::DeviceArrayRead(dctx, A);
    auto diagonal = detail::MakeDiagonalIterator(da.data(), rstart, rend, gcols, lda);
    // We must have this cast outside of THRUST_CALL(). Without it, GCC 6.4 - 7.5, and 11.3.0
    // throw spurious warnings:
    //
    // warning: 'MatDense_CUPM<...>::GetDiagonal(Mat, Vec)::<lambda()>' declared with greater
    // visibility than the type of its field 'MatDense_CUPM<...>::GetDiagonal(Mat,
    // Vec)::<lambda()>::<dv capture>' [-Wattributes]
    // 460 |     PetscCallThrust(
    //     |     ^~~~~~~~~~~~~~~~
    auto         dvp = thrust::device_pointer_cast(dv.data());
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCallThrust(THRUST_CALL(thrust::copy, stream, diagonal.begin(), diagonal.end(), dvp));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
472 
  // Compose the named method on pobj, choosing the host implementation (op_host) when
  // use_host is true and the device implementation (the variadic tail) otherwise
  #define MatComposeOp_CUPM(use_host, pobj, op_str, op_host, ...) \
    do { \
      if (use_host) { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, op_host)); \
      } else { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, __VA_ARGS__)); \
      } \
    } while (0)

  // Same selection logic as MatComposeOp_CUPM, but assigns into the matrix' ops table
  // instead of composing a named function
  #define MatSetOp_CUPM(use_host, mat, op_name, op_host, ...) \
    do { \
      if (use_host) { \
        (mat)->ops->op_name = op_host; \
      } else { \
        (mat)->ops->op_name = __VA_ARGS__; \
      } \
    } while (0)
490 
  // Injects everything from MATDENSECUPM_BASE_HEADER plus using-declarations for the
  // members of MatDense_CUPM<T, Derived>, and befriends that base so it can reach the
  // derived class' private hooks (e.g. MatIMPLCast_(), MATIMPLCUPM_()). The variadic
  // argument is the derived class type (it may contain commas).
  #define MATDENSECUPM_HEADER(T, ...) \
    MATDENSECUPM_BASE_HEADER(T); \
    friend class ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MatIMPLCast; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MATIMPLCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::CreateIMPLDenseCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::SetPreallocation; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DiagonalUnaryTransform; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::Shift; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::GetDiagonal
507 
508 } // namespace impl
509 
510 // ==========================================================================================
511 // MatDense_CUPM -- Implementations
512 // ==========================================================================================
513 
// Dispatch core for the public MatDenseCUPMGetArray[Read|Write]() wrappers below: routes to
// the composed method matching the compile-time access mode. All three enumerators of
// PetscMemoryAccessMode are handled. A write-capable access bumps the object state so
// anything caching results based on A notices the data may change.
template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
inline PetscErrorCode MatDenseCUPMGetArray_Private(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscValidPointer(array, 2);
  switch (access) {
  case PETSC_MEMORY_ACCESS_READ:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_READ_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C(), (Mat, PetscScalar **), (A, array));
    break;
  }
  if (PetscMemoryAccessWrite(access)) PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
534 
// Dispatch core for the public MatDenseCUPMRestoreArray[Read|Write]() wrappers below.
// After a write-capable access the device copy is the authoritative one, so the offload
// mask is set to GPU and the object state is bumped. The caller's pointer is nulled to
// prevent use-after-restore.
template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
inline PetscErrorCode MatDenseCUPMRestoreArray_Private(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  if (array) PetscValidPointer(array, 2);
  switch (access) {
  case PETSC_MEMORY_ACCESS_READ:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_READ_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C(), (Mat, PetscScalar **), (A, array));
    break;
  }
  if (PetscMemoryAccessWrite(access)) {
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
    A->offloadmask = PETSC_OFFLOAD_GPU;
  }
  if (array) *array = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}
559 
// Get read-write access to the device array of a dense CUPM matrix; pair with
// MatDenseCUPMRestoreArray()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
567 
// Get read-only access to the device array; the const_cast only adapts to the non-const
// private helper signature -- the READ path does not modify the matrix state
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
575 
// Get write-only access to the device array (previous contents need not be valid); pair
// with MatDenseCUPMRestoreArrayWrite()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
583 
// Restore access obtained with MatDenseCUPMGetArray(); nulls *array and marks the device
// copy up-to-date
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
591 
// Restore access obtained with MatDenseCUPMGetArrayRead(); read access leaves object state
// and offload mask untouched
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
599 
// Restore access obtained with MatDenseCUPMGetArrayWrite(); nulls *array and marks the
// device copy up-to-date
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
607 
// Substitute the matrix' device array with a user-provided one (analogous to
// MatDensePlaceArray(); presumably temporary until MatDenseCUPMResetArray() -- confirm with
// the backend implementation). Bumps the object state and marks the device copy current.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMPlaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
618 
// Permanently replace the matrix' device array with a user-provided one (analogous to
// MatDenseReplaceArray(); semantics of ownership of the old array are defined by the
// backend implementation). Bumps the object state and marks the device copy current.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
629 
// Undo a previous MatDenseCUPMPlaceArray(), restoring the matrix' original device array;
// bumps the object state since the visible data changes
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMResetArray(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C(), (Mat), (A));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
639 
// Preallocate the device storage of a dense CUPM matrix, optionally using the given device
// array. dctx may be left null, in which case the implementation resolves a default
// context (see MatDense_CUPM::SetPreallocation() above).
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMSetPreallocation(Mat A, PetscScalar *device_data, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMSetPreallocation_C(), (Mat, PetscDeviceContext, PetscScalar *), (A, dctx, device_data));
  PetscFunctionReturn(PETSC_SUCCESS);
}
648 
649 } // namespace cupm
650 
651 } // namespace mat
652 
653 } // namespace Petsc
654 
655 #endif // __cplusplus
656 
657 #endif // PETSCMATDENSECUPMIMPL_H
658