xref: /petsc/include/petsc/private/matdensecupmimpl.h (revision a68bbae58a07f2fb515cab24a67de1159d72e8a2)
1 #ifndef PETSCMATDENSECUPMIMPL_H
2 #define PETSCMATDENSECUPMIMPL_H
3 
4 #define PETSC_SKIP_IMMINTRIN_H_CUDAWORKAROUND 1
5 #include <petsc/private/matimpl.h> /*I <petscmat.h> I*/
6 
7 #ifdef __cplusplus
8   #include <petsc/private/deviceimpl.h>
9   #include <petsc/private/cupmsolverinterface.hpp>
10   #include <petsc/private/cupmobject.hpp>
11 
12   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
13   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
14 
15   #include <thrust/device_vector.h>
16   #include <thrust/device_ptr.h>
17   #include <thrust/iterator/counting_iterator.h>
18   #include <thrust/iterator/transform_iterator.h>
19   #include <thrust/iterator/permutation_iterator.h>
20   #include <thrust/transform.h>
21 
22 namespace Petsc
23 {
24 
25 namespace vec
26 {
27 
28 namespace cupm
29 {
30 
31 namespace impl
32 {
33 
34 template <device::cupm::DeviceType>
35 class VecSeq_CUPM;
36 template <device::cupm::DeviceType>
37 class VecMPI_CUPM;
38 
39 } // namespace impl
40 
41 } // namespace cupm
42 
43 } // namespace vec
44 
45 namespace mat
46 {
47 
48 namespace cupm
49 {
50 
51 namespace impl
52 {
53 
54 // ==========================================================================================
55 // MatDense_CUPM_Base
56 //
57 // A base class to separate out the CRTP code from the common CUPM stuff (like the composed
58 // function names).
59 // ==========================================================================================
60 
61 template <device::cupm::DeviceType T>
62 class MatDense_CUPM_Base : protected device::cupm::impl::CUPMObject<T> {
63 public:
64   PETSC_CUPMOBJECT_HEADER(T);
65 
66   #define MatDenseCUPMComposedOpDecl(OP_NAME) \
67     PETSC_NODISCARD static constexpr const char *PetscConcat(MatDenseCUPM, OP_NAME)() noexcept \
68     { \
69       return T == device::cupm::DeviceType::CUDA ? PetscStringize(PetscConcat(MatDenseCUDA, OP_NAME)) : PetscStringize(PetscConcat(MatDenseHIP, OP_NAME)); \
70     }
71 
72   // clang-format off
73   MatDenseCUPMComposedOpDecl(GetArray_C)
74   MatDenseCUPMComposedOpDecl(GetArrayRead_C)
75   MatDenseCUPMComposedOpDecl(GetArrayWrite_C)
76   MatDenseCUPMComposedOpDecl(RestoreArray_C)
77   MatDenseCUPMComposedOpDecl(RestoreArrayRead_C)
78   MatDenseCUPMComposedOpDecl(RestoreArrayWrite_C)
79   MatDenseCUPMComposedOpDecl(PlaceArray_C)
80   MatDenseCUPMComposedOpDecl(ReplaceArray_C)
81   MatDenseCUPMComposedOpDecl(ResetArray_C)
82     // clang-format on
83 
84   #undef MatDenseCUPMComposedOpDecl
85 
86     PETSC_NODISCARD static constexpr MatType MATSEQDENSECUPM() noexcept;
87   PETSC_NODISCARD static constexpr MatType       MATMPIDENSECUPM() noexcept;
88   PETSC_NODISCARD static constexpr MatType       MATDENSECUPM() noexcept;
89   PETSC_NODISCARD static constexpr MatSolverType MATSOLVERCUPM() noexcept;
90 };
91 
92 // ==========================================================================================
93 // MatDense_CUPM_Base -- Public API
94 // ==========================================================================================
95 
96 template <device::cupm::DeviceType T>
97 inline constexpr MatType MatDense_CUPM_Base<T>::MATSEQDENSECUPM() noexcept
98 {
99   return T == device::cupm::DeviceType::CUDA ? MATSEQDENSECUDA : MATSEQDENSEHIP;
100 }
101 
102 template <device::cupm::DeviceType T>
103 inline constexpr MatType MatDense_CUPM_Base<T>::MATMPIDENSECUPM() noexcept
104 {
105   return T == device::cupm::DeviceType::CUDA ? MATMPIDENSECUDA : MATMPIDENSEHIP;
106 }
107 
108 template <device::cupm::DeviceType T>
109 inline constexpr MatType MatDense_CUPM_Base<T>::MATDENSECUPM() noexcept
110 {
111   return T == device::cupm::DeviceType::CUDA ? MATDENSECUDA : MATDENSEHIP;
112 }
113 
114 template <device::cupm::DeviceType T>
115 inline constexpr MatSolverType MatDense_CUPM_Base<T>::MATSOLVERCUPM() noexcept
116 {
117   return T == device::cupm::DeviceType::CUDA ? MATSOLVERCUDA : MATSOLVERHIP;
118 }
119 
// Boilerplate for classes deriving from MatDense_CUPM_Base<T>: pulls in the CUPM object
// header, aliases the corresponding CUPM vector classes, and re-exposes the base class'
// type-name and composed-operation-name getters (which would otherwise be hidden since the
// base is inherited with protected access)
  #define MATDENSECUPM_BASE_HEADER(T) \
    PETSC_CUPMOBJECT_HEADER(T); \
    using VecSeq_CUPM = ::Petsc::vec::cupm::impl::VecSeq_CUPM<T>; \
    using VecMPI_CUPM = ::Petsc::vec::cupm::impl::VecMPI_CUPM<T>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSEQDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATMPIDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSOLVERCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C
137 
138 // forward declare
139 template <device::cupm::DeviceType>
140 class MatDense_Seq_CUPM;
141 template <device::cupm::DeviceType>
142 class MatDense_MPI_CUPM;
143 
144 // ==========================================================================================
145 // MatDense_CUPM
146 //
// The true "base" class for MatDenseCUPM. The reason both MatDense_CUPM and
// MatDense_CUPM_Base exist is to separate the CRTP code from the non-CRTP code, so that the
// generic functions can be called via templates below.
150 // ==========================================================================================
151 
template <device::cupm::DeviceType T, typename Derived>
class MatDense_CUPM : protected MatDense_CUPM_Base<T> {
protected:
  MATDENSECUPM_BASE_HEADER(T);

  // RAII accessor for the matrix array, defined below
  template <PetscMemType, PetscMemoryAccessMode>
  class MatrixArray;

  // Cast the Mat to its host struct, i.e. return the result of (Mat_SeqDense *)m->data.
  // Delegates to the derived class' MatIMPLCast_()
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr auto    MatIMPLCast(Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(U::MatIMPLCast_(m))
  // MatType of the derived class' implementation (seq or mpi), defined below
  PETSC_NODISCARD static constexpr MatType MATIMPLCUPM() noexcept;

  // Common core for MatCreateSeqDenseCUPM()/MatCreateMPIDenseCUPM(); defined below
  static PetscErrorCode CreateIMPLDenseCUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, PetscInt, PetscScalar *, Mat *, PetscDeviceContext, bool) noexcept;
  // Layout setup + device allocation (optionally adopting a user array); defined below
  static PetscErrorCode SetPreallocation(Mat, PetscDeviceContext, PetscScalar * = nullptr) noexcept;

  // Apply a unary functor in-place to the diagonal of the local sub-matrix; defined below
  template <typename F>
  static PetscErrorCode DiagonalUnaryTransform(Mat, PetscInt, PetscInt, PetscInt, PetscDeviceContext, F &&) noexcept;

  // Scoped accessors: each returns a MatrixArray which acquires the array on construction
  // and restores it on destruction, for the given memory type and access mode
  PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
};
178 
179 // ==========================================================================================
180 // MatDense_CUPM::MatrixArray
181 // ==========================================================================================
182 
// RAII accessor for the matrix array: the constructor calls D::GetArray<MT, MA>() and the
// destructor calls D::RestoreArray<MT, MA>(). The acquired pointer lives in the
// RestoreableArray base (this->ptr_)
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
class MatDense_CUPM<T, D>::MatrixArray : public device::cupm::impl::RestoreableArray<T, MT, MA> {
  using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>;

public:
  MatrixArray(PetscDeviceContext, Mat) noexcept;
  ~MatrixArray() noexcept;

  // must declare move constructor since we declare a destructor
  constexpr MatrixArray(MatrixArray &&) noexcept;

private:
  // the matrix whose array was acquired; nullptr after this object is moved-from
  Mat m_ = nullptr;
};
198 
199 // ==========================================================================================
200 // MatDense_CUPM::MatrixArray -- Public API
201 // ==========================================================================================
202 
// Acquire the array of m (memory type MT, access mode MA); the pointer is stored in
// this->ptr_ by D::GetArray(). Errors abort since a constructor cannot return an error code
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(PetscDeviceContext dctx, Mat m) noexcept : base_type{dctx}, m_{m}
{
  PetscFunctionBegin;
  PetscCallAbort(PETSC_COMM_SELF, D::template GetArray<MT, MA>(m, &this->ptr_, dctx));
  PetscFunctionReturnVoid();
}
211 
// Restore the array acquired in the constructor. Errors abort since a destructor cannot
// return an error code. Note m_ is nullptr if this object was moved-from -- presumably
// D::RestoreArray() tolerates that; TODO confirm against the derived implementations
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::~MatrixArray() noexcept
{
  PetscFunctionBegin;
  PetscCallAbort(PETSC_COMM_SELF, D::template RestoreArray<MT, MA>(m_, &this->ptr_, this->dctx_));
  PetscFunctionReturnVoid();
}
220 
// Move constructor: steals ownership of the acquired array; the moved-from object's m_ is
// reset to nullptr so its destructor does not restore the array a second time
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline constexpr MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(MatrixArray &&other) noexcept : base_type{std::move(other)}, m_{util::exchange(other.m_, nullptr)}
{
}
226 
227 // ==========================================================================================
228 // MatDense_CUPM -- Protected API
229 // ==========================================================================================
230 
// Return the MatType of the concrete implementation; delegates to the derived class'
// MATIMPLCUPM_() via CRTP
template <device::cupm::DeviceType T, typename D>
inline constexpr MatType MatDense_CUPM<T, D>::MATIMPLCUPM() noexcept
{
  return D::MATIMPLCUPM_();
}
236 
237 // Common core for MatCreateSeqDenseCUPM() and MatCreateMPIDenseCUPM()
// Create a dense CUPM matrix of local size m x n (global M x N) on comm, optionally
// preallocating it with the (possibly NULL) array data. On success *A holds the new matrix
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::CreateIMPLDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx, bool preallocate) noexcept
{
  Mat mat;

  PetscFunctionBegin;
  PetscValidPointer(A, 7); // A is the 7th argument
  PetscCall(MatCreate(comm, &mat));
  PetscCall(MatSetSizes(mat, m, n, M, N));
  PetscCall(MatSetType(mat, D::MATIMPLCUPM()));
  if (preallocate) {
    // replace a NULL dctx with the default device context before preallocating
    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(D::SetPreallocation(mat, dctx, data));
  }
  *A = mat;
  PetscFunctionReturn(PETSC_SUCCESS);
}
255 
// Set up the row/column layouts of A and perform the implementation-specific preallocation
// (optionally adopting device_array instead of allocating). Marks A preallocated + assembled
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::SetPreallocation(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  // cannot use PetscValidHeaderSpecificType(..., MATIMPLCUPM()) since the incoming matrix
  // might be the local (sequential) matrix of a MatMPIDense_CUPM. Since this would be called
  // from the MPI matrix's impl, MATIMPLCUPM() would return MATMPIDENSECUPM()
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscCheckTypeNames(A, D::MATSEQDENSECUPM(), D::MATMPIDENSECUPM());
  PetscCall(PetscLayoutSetUp(A->rmap));
  PetscCall(PetscLayoutSetUp(A->cmap));
  // derived classes implement the actual allocation in SetPreallocation_()
  PetscCall(D::SetPreallocation_(A, dctx, device_array));
  A->preallocated = PETSC_TRUE;
  A->assembled    = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
272 
273 namespace detail
274 {
275 
276 // ==========================================================================================
277 // MatrixIteratorBase
278 //
279 // A base class for creating thrust iterators over the local sub-matrix. This will set up the
280 // proper iterator definitions so thrust knows how to handle things properly. Template
281 // parameters are as follows:
282 //
283 // - Iterator:
284 // The type of the primary array iterator. Usually this is
// thrust::device_vector<PetscScalar>::iterator (or a thrust::device_ptr<PetscScalar>).
286 //
287 // - IndexFunctor:
288 // This should be a functor which contains an operator() that when called with an index `i`,
289 // returns the i'th permuted index into the array. For example, it could return the i'th
290 // diagonal entry.
291 // ==========================================================================================
template <typename Iterator, typename IndexFunctor>
class MatrixIteratorBase {
public:
  using array_iterator_type = Iterator;
  using index_functor_type  = IndexFunctor;

  using difference_type     = typename thrust::iterator_difference<array_iterator_type>::type;
  using CountingIterator    = thrust::counting_iterator<difference_type>;
  using TransformIterator   = thrust::transform_iterator<index_functor_type, CountingIterator>;
  using PermutationIterator = thrust::permutation_iterator<array_iterator_type, TransformIterator>;
  using iterator            = PermutationIterator; // type of the begin/end iterator

  // first/last delimit the underlying array range; idx_func maps a linear counting index to
  // the permuted index into that range
  constexpr MatrixIteratorBase(array_iterator_type first, array_iterator_type last, index_functor_type idx_func) noexcept : first{std::move(first)}, last{std::move(last)}, func{std::move(idx_func)} { }

  // Iterator over the permuted entries: dereferencing yields first[func(0)], first[func(1)],
  // ... Note end() is not defined here; derived classes (e.g. DiagonalIterator below) supply
  // it since the number of selected entries depends on the functor
  PETSC_NODISCARD iterator begin() const noexcept
  {
    return PermutationIterator{
      first, TransformIterator{CountingIterator{0}, func}
    };
  }

protected:
  array_iterator_type first; // start of the underlying array range
  array_iterator_type last;  // one-past-the-end of the underlying array range
  index_functor_type  func;  // linear index -> permuted index
};
318 
319 // ==========================================================================================
320 // StridedIndexFunctor
321 //
// Index functor which maps a linear index range onto strided matrix indices. Usually used
// to select the diagonal.
324 // ==========================================================================================
template <typename T>
struct StridedIndexFunctor {
  // map the i'th linear index to the i'th strided index, i.e. i -> stride * i
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &i) const noexcept { return stride * i; }

  T stride;
};
331 
// Iterator over every stride'th element of [first, last), i.e. the diagonal of a dense
// matrix when stride = lda + 1
template <typename Iterator>
class DiagonalIterator : public MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>> {
public:
  using base_type = MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>>;

  using difference_type = typename base_type::difference_type;
  using iterator        = typename base_type::iterator;

  constexpr DiagonalIterator(Iterator first, Iterator last, difference_type stride) noexcept : base_type{std::move(first), std::move(last), {stride}} { }

  // number of selected entries is ceil((last - first) / stride), computed via the usual
  // (len + stride - 1) / stride rounding-up integer division
  PETSC_NODISCARD iterator end() const noexcept { return this->begin() + (this->last - this->first + this->func.stride - 1) / this->func.stride; }
};
344 
345 } // namespace detail
346 
// Apply functor in-place to each diagonal entry of the local sub-matrix of A whose global
// row index lies in [rstart, min(rend, cols)). Runs asynchronously on the stream held by
// dctx; logs one flop per transformed entry
template <device::cupm::DeviceType T, typename D>
template <typename F>
inline PetscErrorCode MatDense_CUPM<T, D>::DiagonalUnaryTransform(Mat A, PetscInt rstart, PetscInt rend, PetscInt cols, PetscDeviceContext dctx, F &&functor) noexcept
{
  const auto rend2 = std::min(rend, cols); // the diagonal cannot extend past the last column

  PetscFunctionBegin;
  if (rend2 > rstart) {
    const auto da = D::DeviceArrayReadWrite(dctx, A);
    PetscInt   lda;

    PetscCall(MatDenseGetLDA(A, &lda));
    {
      using DiagonalIterator  = detail::DiagonalIterator<thrust::device_vector<PetscScalar>::iterator>;
      const auto        dptr  = thrust::device_pointer_cast(da.data());
      // diagonal entry i (rstart <= i < rend2) is at local offset (i - rstart) + i * lda, so
      // the range starts at rstart * lda, ends past (rend2 - rstart - 1) + (rend2 - 1) * lda,
      // and consecutive diagonal entries are lda + 1 apart
      const std::size_t begin = rstart * lda;
      const std::size_t end   = rend2 - rstart + rend2 * lda;
      DiagonalIterator  diagonal{dptr + begin, dptr + end, lda + 1};
      cupmStream_t      stream;

      PetscCall(D::GetHandlesFrom_(dctx, &stream));
      // transform in-place (input and output iterators are both diagonal.begin())
      // clang-format off
      PetscCallThrust(
        THRUST_CALL(
          thrust::transform,
          stream,
          diagonal.begin(), diagonal.end(), diagonal.begin(),
          std::forward<F>(functor)
        )
      );
      // clang-format on
    }
    PetscCall(PetscLogGpuFlops(rend2 - rstart)); // estimate: one flop per diagonal entry
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
383 
// Compose op_str on pobj with either the host implementation (use_host true) or the device
// implementation given by the remaining arguments
  #define MatComposeOp_CUPM(use_host, pobj, op_str, op_host, ...) \
    do { \
      if (use_host) { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, op_host)); \
      } else { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, __VA_ARGS__)); \
      } \
    } while (0)
392 
// Set the op_name entry of mat's function table to either the host implementation (use_host
// true) or the device implementation given by the remaining arguments
  #define MatSetOp_CUPM(use_host, mat, op_name, op_host, ...) \
    do { \
      if (use_host) { \
        (mat)->ops->op_name = op_host; \
      } else { \
        (mat)->ops->op_name = __VA_ARGS__; \
      } \
    } while (0)
401 
// Boilerplate for concrete classes deriving from MatDense_CUPM<T, Derived>: expands the base
// header, befriends MatDense_CUPM (so it may reach the Derived hooks), and re-exposes the
// inherited protected helpers
  #define MATDENSECUPM_HEADER(T, ...) \
    MATDENSECUPM_BASE_HEADER(T); \
    friend class ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MatIMPLCast; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MATIMPLCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::CreateIMPLDenseCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::SetPreallocation; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DiagonalUnaryTransform
416 
417 } // namespace impl
418 
419 namespace
420 {
421 
// Common core of the MatDenseCUPMGetArray[Read|Write]() wrappers below: dispatch on the
// compile-time access mode to the corresponding composed method on A, then bump the object
// state if the access can write
template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
inline PetscErrorCode MatDenseCUPMGetArray_Private(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscValidPointer(array, 2);
  // access is a template parameter, so the switch resolves at compile time
  switch (access) {
  case PETSC_MEMORY_ACCESS_READ:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_READ_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C(), (Mat, PetscScalar **), (A, array));
    break;
  }
  if (PetscMemoryAccessWrite(access)) PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
442 
// Common core of the MatDenseCUPMRestoreArray[Read|Write]() wrappers below: dispatch on the
// compile-time access mode, bump the object state and mark the GPU copy current for write
// access, and null out the caller's pointer
template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
inline PetscErrorCode MatDenseCUPMRestoreArray_Private(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  if (array) PetscValidPointer(array, 2); // array may be NULL on restore
  switch (access) {
  case PETSC_MEMORY_ACCESS_READ:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C(), (Mat, PetscScalar **), (A, array));
    break;
  case PETSC_MEMORY_ACCESS_READ_WRITE:
    PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C(), (Mat, PetscScalar **), (A, array));
    break;
  }
  if (PetscMemoryAccessWrite(access)) {
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
    A->offloadmask = PETSC_OFFLOAD_GPU; // device copy is now the authoritative one
  }
  if (array) *array = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}
467 
// Get the device array of A for reading and writing
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
475 
// Get the device array of A for read-only access. The const_cast is only to reuse the
// non-const private core; read access never writes through the pointer
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
483 
// Get the device array of A for write-only access (previous contents need not be valid)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
491 
// Restore the array obtained from MatDenseCUPMGetArray()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
499 
// Restore the array obtained from MatDenseCUPMGetArrayRead(); const_cast mirrors the Get
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
507 
// Restore the array obtained from MatDenseCUPMGetArrayWrite()
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
515 
// Temporarily replace A's device array with the user-provided array (undone by
// MatDenseCUPMResetArray()); bumps the object state and marks the GPU copy current
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMPlaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
526 
// Permanently replace A's device array with the user-provided array (unlike PlaceArray this
// is not undone by ResetArray); bumps the object state and marks the GPU copy current
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
537 
// Restore the device array that was active before MatDenseCUPMPlaceArray(); bumps the
// object state
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMResetArray(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C(), (Mat), (A));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
547 
548 } // anonymous namespace
549 
550 } // namespace cupm
551 
552 } // namespace mat
553 
554 } // namespace Petsc
555 
556 #endif // __cplusplus
557 
558 #endif // PETSCMATDENSECUPMIMPL_H
559