xref: /petsc/include/petsc/private/matdensecupmimpl.h (revision 9d13fa56c5c6523e02c36edc0e4e22bf2d0334a8)
1 #ifndef PETSCMATDENSECUPMIMPL_H
2 #define PETSCMATDENSECUPMIMPL_H
3 
4 #define PETSC_SKIP_IMMINTRIN_H_CUDAWORKAROUND 1
5 #include <petsc/private/matimpl.h> /*I <petscmat.h> I*/
6 
7 #ifdef __cplusplus
8   #include <petsc/private/deviceimpl.h>
9   #include <petsc/private/cupmsolverinterface.hpp>
10   #include <petsc/private/cupmobject.hpp>
11 
12   #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
13   #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
14 
15   #include <thrust/device_vector.h>
16   #include <thrust/device_ptr.h>
17   #include <thrust/iterator/counting_iterator.h>
18   #include <thrust/iterator/transform_iterator.h>
19   #include <thrust/iterator/permutation_iterator.h>
20   #include <thrust/transform.h>
21 
22 namespace Petsc
23 {
24 
25 namespace vec
26 {
27 
28 namespace cupm
29 {
30 
31 namespace impl
32 {
33 
// Forward declarations of the CUPM vector implementation classes; this header only needs
// their names (see the VecSeq_CUPM/VecMPI_CUPM aliases in MATDENSECUPM_BASE_HEADER below)
template <device::cupm::DeviceType>
class VecSeq_CUPM;
template <device::cupm::DeviceType>
class VecMPI_CUPM;
38 
39 } // namespace impl
40 
41 } // namespace cupm
42 
43 } // namespace vec
44 
45 namespace mat
46 {
47 
48 namespace cupm
49 {
50 
51 namespace impl
52 {
53 
54 // ==========================================================================================
55 // MatDense_CUPM_Base
56 //
57 // A base class to separate out the CRTP code from the common CUPM stuff (like the composed
58 // function names).
59 // ==========================================================================================
60 
template <device::cupm::DeviceType T>
class MatDense_CUPM_Base : protected device::cupm::impl::CUPMObject<T> {
public:
  PETSC_CUPMOBJECT_HEADER(T);

  // Declares a static member function MatDenseCUPM<OP_NAME>() returning the name of the
  // composed operation as a string literal, spelled "MatDenseCUDA<OP_NAME>" or
  // "MatDenseHIP<OP_NAME>" depending on the device type T (selected at compile time)
  #define MatDenseCUPMComposedOpDecl(OP_NAME) \
    PETSC_NODISCARD static constexpr const char *PetscConcat(MatDenseCUPM, OP_NAME)() noexcept \
    { \
      return T == device::cupm::DeviceType::CUDA ? PetscStringize(PetscConcat(MatDenseCUDA, OP_NAME)) : PetscStringize(PetscConcat(MatDenseHIP, OP_NAME)); \
    }

  // clang-format off
  MatDenseCUPMComposedOpDecl(GetArray_C)
  MatDenseCUPMComposedOpDecl(GetArrayRead_C)
  MatDenseCUPMComposedOpDecl(GetArrayWrite_C)
  MatDenseCUPMComposedOpDecl(RestoreArray_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayRead_C)
  MatDenseCUPMComposedOpDecl(RestoreArrayWrite_C)
  MatDenseCUPMComposedOpDecl(PlaceArray_C)
  MatDenseCUPMComposedOpDecl(ReplaceArray_C)
  MatDenseCUPMComposedOpDecl(ResetArray_C)
  MatDenseCUPMComposedOpDecl(SetPreallocation_C)
  // clang-format on

  #undef MatDenseCUPMComposedOpDecl

  // Device-flavor-specific type and solver names (CUDA vs HIP spellings), defined below
  PETSC_NODISCARD static constexpr MatType       MATSEQDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATMPIDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatType       MATDENSECUPM() noexcept;
  PETSC_NODISCARD static constexpr MatSolverType MATSOLVERCUPM() noexcept;
};
92 
93 // ==========================================================================================
94 // MatDense_CUPM_Base -- Public API
95 // ==========================================================================================
96 
97 template <device::cupm::DeviceType T>
98 inline constexpr MatType MatDense_CUPM_Base<T>::MATSEQDENSECUPM() noexcept
99 {
100   return T == device::cupm::DeviceType::CUDA ? MATSEQDENSECUDA : MATSEQDENSEHIP;
101 }
102 
103 template <device::cupm::DeviceType T>
104 inline constexpr MatType MatDense_CUPM_Base<T>::MATMPIDENSECUPM() noexcept
105 {
106   return T == device::cupm::DeviceType::CUDA ? MATMPIDENSECUDA : MATMPIDENSEHIP;
107 }
108 
109 template <device::cupm::DeviceType T>
110 inline constexpr MatType MatDense_CUPM_Base<T>::MATDENSECUPM() noexcept
111 {
112   return T == device::cupm::DeviceType::CUDA ? MATDENSECUDA : MATDENSEHIP;
113 }
114 
115 template <device::cupm::DeviceType T>
116 inline constexpr MatSolverType MatDense_CUPM_Base<T>::MATSOLVERCUPM() noexcept
117 {
118   return T == device::cupm::DeviceType::CUDA ? MATSOLVERCUDA : MATSOLVERHIP;
119 }
120 
  // Imports every MatDense_CUPM_Base member name into the scope of a deriving class
  // (names of a dependent base class are not found by unqualified lookup in C++), along
  // with aliases for the matching CUPM vector implementation classes
  #define MATDENSECUPM_BASE_HEADER(T) \
    PETSC_CUPMOBJECT_HEADER(T); \
    using VecSeq_CUPM = ::Petsc::vec::cupm::impl::VecSeq_CUPM<T>; \
    using VecMPI_CUPM = ::Petsc::vec::cupm::impl::VecMPI_CUPM<T>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSEQDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATMPIDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATDENSECUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MATSOLVERCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM_Base<T>::MatDenseCUPMSetPreallocation_C

// forward declare the concrete sequential and MPI implementations
template <device::cupm::DeviceType>
class MatDense_Seq_CUPM;
template <device::cupm::DeviceType>
class MatDense_MPI_CUPM;
145 
146 // ==========================================================================================
147 // MatDense_CUPM
148 //
149 // The true "base" class for MatDenseCUPM. The reason MatDense_CUPM and MatDense_CUPM_Base
150 // exist is to separate out the CRTP code from the non-crtp code so that the generic functions
151 // can be called via templates below.
152 // ==========================================================================================
153 
template <device::cupm::DeviceType T, typename Derived>
class MatDense_CUPM : protected MatDense_CUPM_Base<T> {
protected:
  MATDENSECUPM_BASE_HEADER(T);

  // RAII guard over the matrix array for a given memory type and access mode; defined
  // below
  template <PetscMemType, PetscMemoryAccessMode>
  class MatrixArray;

  // Cast the Mat to its host struct, i.e. return the result of (Mat_SeqDense *)m->data
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr auto    MatIMPLCast(Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(U::MatIMPLCast_(m))
  // MatType of the concrete implementation, deferred to Derived::MATIMPLCUPM_()
  PETSC_NODISCARD static constexpr MatType MATIMPLCUPM() noexcept;

  // Shared creation core for the seq and MPI dense CUPM matrix constructors
  static PetscErrorCode CreateIMPLDenseCUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, PetscInt, PetscScalar *, Mat *, PetscDeviceContext, bool) noexcept;
  static PetscErrorCode SetPreallocation(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  // Apply a unary functor in-place to the diagonal entries in the given row range
  template <typename F>
  static PetscErrorCode DiagonalUnaryTransform(Mat, PetscInt, PetscInt, PetscInt, PetscDeviceContext, F &&) noexcept;

  // Scoped array accessors: the returned MatrixArray acquires the array on construction
  // and restores it on destruction. NOTE: PETSC_DECLTYPE_AUTO_RETURNS supplies the entire
  // function body, hence no trailing semicolons on these declarations.
  PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, m})
  PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>{dctx, m})
  PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Mat m) noexcept PETSC_DECLTYPE_AUTO_RETURNS(MatrixArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, m})
};
180 
181 // ==========================================================================================
182 // MatDense_CUPM::MatrixArray
183 // ==========================================================================================
184 
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
class MatDense_CUPM<T, D>::MatrixArray : public device::cupm::impl::RestoreableArray<T, MT, MA> {
  using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>;

public:
  // Acquires the array via D::GetArray<MT, MA>() ...
  MatrixArray(PetscDeviceContext, Mat) noexcept;
  // ... and returns it via D::RestoreArray<MT, MA>()
  ~MatrixArray() noexcept;

  // must declare move constructor since we declare a destructor
  constexpr MatrixArray(MatrixArray &&) noexcept;

private:
  // matrix whose array is held; set to nullptr when this object is moved-from
  Mat m_ = nullptr;
};
200 
201 // ==========================================================================================
202 // MatDense_CUPM::MatrixArray -- Public API
203 // ==========================================================================================
204 
// Acquire the matrix array with the requested memory type MT and access mode MA. Errors
// abort (on PETSC_COMM_SELF) since a constructor cannot return a PetscErrorCode.
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(PetscDeviceContext dctx, Mat m) noexcept : base_type{dctx}, m_{m}
{
  PetscFunctionBegin;
  PetscCallAbort(PETSC_COMM_SELF, D::template GetArray<MT, MA>(m, &this->ptr_, dctx));
  PetscFunctionReturnVoid();
}
213 
// Return the array acquired in the constructor. Errors abort since a destructor cannot
// return a PetscErrorCode (and must not throw).
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline MatDense_CUPM<T, D>::MatrixArray<MT, MA>::~MatrixArray() noexcept
{
  PetscFunctionBegin;
  // NOTE(review): if *this was moved-from then m_ is nullptr here; this assumes
  // D::RestoreArray() tolerates a NULL Mat -- confirm in the derived implementation
  PetscCallAbort(PETSC_COMM_SELF, D::template RestoreArray<MT, MA>(m_, &this->ptr_, this->dctx_));
  PetscFunctionReturnVoid();
}
222 
// Steal the state of `other`, leaving it with m_ == nullptr so ownership of the acquired
// array transfers to *this
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline constexpr MatDense_CUPM<T, D>::MatrixArray<MT, MA>::MatrixArray(MatrixArray &&other) noexcept : base_type{std::move(other)}, m_{util::exchange(other.m_, nullptr)}
{
}
228 
229 // ==========================================================================================
230 // MatDense_CUPM -- Protected API
231 // ==========================================================================================
232 
// MatType of the concrete implementation; simply defers to the derived class'
// MATIMPLCUPM_() (CRTP)
template <device::cupm::DeviceType T, typename D>
inline constexpr MatType MatDense_CUPM<T, D>::MATIMPLCUPM() noexcept
{
  return D::MATIMPLCUPM_();
}
238 
239 // Common core for MatCreateSeqDenseCUPM() and MatCreateMPIDenseCUPM()
240 template <device::cupm::DeviceType T, typename D>
241 inline PetscErrorCode MatDense_CUPM<T, D>::CreateIMPLDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx, bool preallocate) noexcept
242 {
243   Mat mat;
244 
245   PetscFunctionBegin;
246   PetscValidPointer(A, 7);
247   PetscCall(MatCreate(comm, &mat));
248   PetscCall(MatSetSizes(mat, m, n, M, N));
249   PetscCall(MatSetType(mat, D::MATIMPLCUPM()));
250   if (preallocate) {
251     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
252     PetscCall(D::SetPreallocation(mat, dctx, data));
253   }
254   *A = mat;
255   PetscFunctionReturn(PETSC_SUCCESS);
256 }
257 
// Set up the storage of A (adopting `device_array` if non-NULL) and mark the matrix as
// preallocated and assembled. Accepts both the sequential and MPI dense CUPM types.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode MatDense_CUPM<T, D>::SetPreallocation(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  // cannot use PetscValidHeaderSpecificType(..., MATIMPLCUPM()) since the incoming matrix
  // might be the local (sequential) matrix of a MatMPIDense_CUPM. Since this would be called
  // from the MPI matrix'es impl MATIMPLCUPM() would return MATMPIDENSECUPM().
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscCheckTypeNames(A, D::MATSEQDENSECUPM(), D::MATMPIDENSECUPM());
  // finalize the layouts before the implementation allocates from them
  PetscCall(PetscLayoutSetUp(A->rmap));
  PetscCall(PetscLayoutSetUp(A->cmap));
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(D::SetPreallocation_(A, dctx, device_array));
  A->preallocated = PETSC_TRUE;
  A->assembled    = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
275 
276 namespace detail
277 {
278 
279 // ==========================================================================================
280 // MatrixIteratorBase
281 //
282 // A base class for creating thrust iterators over the local sub-matrix. This will set up the
283 // proper iterator definitions so thrust knows how to handle things properly. Template
284 // parameters are as follows:
285 //
286 // - Iterator:
287 // The type of the primary array iterator. Usually this is
288 // thrust::device_pointer<PetscScalar>::iterator.
289 //
290 // - IndexFunctor:
291 // This should be a functor which contains an operator() that when called with an index `i`,
292 // returns the i'th permuted index into the array. For example, it could return the i'th
293 // diagonal entry.
294 // ==========================================================================================
295 template <typename Iterator, typename IndexFunctor>
296 class MatrixIteratorBase {
297 public:
298   using array_iterator_type = Iterator;
299   using index_functor_type  = IndexFunctor;
300 
301   using difference_type     = typename thrust::iterator_difference<array_iterator_type>::type;
302   using CountingIterator    = thrust::counting_iterator<difference_type>;
303   using TransformIterator   = thrust::transform_iterator<index_functor_type, CountingIterator>;
304   using PermutationIterator = thrust::permutation_iterator<array_iterator_type, TransformIterator>;
305   using iterator            = PermutationIterator; // type of the begin/end iterator
306 
307   constexpr MatrixIteratorBase(array_iterator_type first, array_iterator_type last, index_functor_type idx_func) noexcept : first{std::move(first)}, last{std::move(last)}, func{std::move(idx_func)} { }
308 
309   PETSC_NODISCARD iterator begin() const noexcept
310   {
311     return PermutationIterator{
312       first, TransformIterator{CountingIterator{0}, func}
313     };
314   }
315 
316 protected:
317   array_iterator_type first;
318   array_iterator_type last;
319   index_functor_type  func;
320 };
321 
322 // ==========================================================================================
323 // StridedIndexFunctor
324 //
325 // Iterator which permutes a linear index range into strided matrix indices. Usually used to
326 // get the diagonal.
327 // ==========================================================================================
// Maps a linear index i to the strided index stride * i; with stride = lda + 1 this
// produces the flat indices of a matrix diagonal (see DiagonalIterator below)
template <typename T>
struct StridedIndexFunctor {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &i) const noexcept { return stride * i; }

  T stride;
};
334 
// Iterates the diagonal entries of the array range [first, last), visiting every
// stride'th element (stride is typically lda + 1)
template <typename Iterator>
class DiagonalIterator : public MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>> {
public:
  using base_type = MatrixIteratorBase<Iterator, StridedIndexFunctor<typename thrust::iterator_difference<Iterator>::type>>;

  using difference_type = typename base_type::difference_type;
  using iterator        = typename base_type::iterator;

  constexpr DiagonalIterator(Iterator first, Iterator last, difference_type stride) noexcept : base_type{std::move(first), std::move(last), {stride}} { }

  // number of visited entries is ceil((last - first) / stride), computed here with the
  // usual (n + stride - 1) / stride integer round-up
  PETSC_NODISCARD iterator end() const noexcept { return this->begin() + (this->last - this->first + this->func.stride - 1) / this->func.stride; }
};
347 
348 } // namespace detail
349 
// Apply `functor` in-place to every (global) diagonal entry (j, j), j in
// [rstart, min(rend, cols)), stored in the local column-major array of A
template <device::cupm::DeviceType T, typename D>
template <typename F>
inline PetscErrorCode MatDense_CUPM<T, D>::DiagonalUnaryTransform(Mat A, PetscInt rstart, PetscInt rend, PetscInt cols, PetscDeviceContext dctx, F &&functor) noexcept
{
  const auto rend2 = std::min(rend, cols); // clamp the row range to the number of columns

  PetscFunctionBegin;
  if (rend2 > rstart) {
    const auto da = D::DeviceArrayReadWrite(dctx, A);
    PetscInt   lda;

    PetscCall(MatDenseGetLDA(A, &lda));
    {
      using DiagonalIterator  = detail::DiagonalIterator<thrust::device_vector<PetscScalar>::iterator>;
      const auto        dptr  = thrust::device_pointer_cast(da.data());
      // Global diagonal entry (j, j) lives at local row j - rstart, column j, i.e. flat
      // index (j - rstart) + j * lda. `begin`/`end` are that index at j = rstart and
      // j = rend2; consecutive diagonal entries are lda + 1 apart.
      const std::size_t begin = rstart * lda;
      const std::size_t end   = rend2 - rstart + rend2 * lda;
      DiagonalIterator  diagonal{dptr + begin, dptr + end, lda + 1};
      cupmStream_t      stream;

      PetscCall(D::GetHandlesFrom_(dctx, &stream));
      // clang-format off
      PetscCallThrust(
        THRUST_CALL(
          thrust::transform,
          stream,
          diagonal.begin(), diagonal.end(), diagonal.begin(),
          std::forward<F>(functor)
        )
      );
      // clang-format on
    }
    // one functor application per transformed diagonal entry
    PetscCall(PetscLogGpuFlops(rend2 - rstart));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
386 
  // Compose `op_str` on `pobj`, selecting the host implementation `op_host` when
  // `use_host` is true and the device (CUPM) implementation (__VA_ARGS__) otherwise
  #define MatComposeOp_CUPM(use_host, pobj, op_str, op_host, ...) \
    do { \
      if (use_host) { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, op_host)); \
      } else { \
        PetscCall(PetscObjectComposeFunction(pobj, op_str, __VA_ARGS__)); \
      } \
    } while (0)

  // Same selection as MatComposeOp_CUPM(), but assigns the matrix' ops-table entry
  // `op_name` instead of composing a named function
  #define MatSetOp_CUPM(use_host, mat, op_name, op_host, ...) \
    do { \
      if (use_host) { \
        (mat)->ops->op_name = op_host; \
      } else { \
        (mat)->ops->op_name = __VA_ARGS__; \
      } \
    } while (0)
404 
  // Header for concrete MatDense_CUPM subclasses (__VA_ARGS__ is the derived class,
  // CRTP-style): befriends the base so it may call into the derived class, and imports
  // the base-class helpers into the subclass' scope
  #define MATDENSECUPM_HEADER(T, ...) \
    MATDENSECUPM_BASE_HEADER(T); \
    friend class ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MatIMPLCast; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::MATIMPLCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::CreateIMPLDenseCUPM; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::SetPreallocation; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DeviceArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayRead; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::HostArrayReadWrite; \
    using ::Petsc::mat::cupm::impl::MatDense_CUPM<T, __VA_ARGS__>::DiagonalUnaryTransform
419 
420 } // namespace impl
421 
422 namespace
423 {
424 
425 template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
426 inline PetscErrorCode MatDenseCUPMGetArray_Private(Mat A, PetscScalar **array) noexcept
427 {
428   PetscFunctionBegin;
429   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
430   PetscValidPointer(array, 2);
431   switch (access) {
432   case PETSC_MEMORY_ACCESS_READ:
433     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayRead_C(), (Mat, PetscScalar **), (A, array));
434     break;
435   case PETSC_MEMORY_ACCESS_WRITE:
436     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArrayWrite_C(), (Mat, PetscScalar **), (A, array));
437     break;
438   case PETSC_MEMORY_ACCESS_READ_WRITE:
439     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMGetArray_C(), (Mat, PetscScalar **), (A, array));
440     break;
441   }
442   if (PetscMemoryAccessWrite(access)) PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
443   PetscFunctionReturn(PETSC_SUCCESS);
444 }
445 
446 template <device::cupm::DeviceType T, PetscMemoryAccessMode access>
447 inline PetscErrorCode MatDenseCUPMRestoreArray_Private(Mat A, PetscScalar **array) noexcept
448 {
449   PetscFunctionBegin;
450   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
451   if (array) PetscValidPointer(array, 2);
452   switch (access) {
453   case PETSC_MEMORY_ACCESS_READ:
454     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayRead_C(), (Mat, PetscScalar **), (A, array));
455     break;
456   case PETSC_MEMORY_ACCESS_WRITE:
457     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArrayWrite_C(), (Mat, PetscScalar **), (A, array));
458     break;
459   case PETSC_MEMORY_ACCESS_READ_WRITE:
460     PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMRestoreArray_C(), (Mat, PetscScalar **), (A, array));
461     break;
462   }
463   if (PetscMemoryAccessWrite(access)) {
464     PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
465     A->offloadmask = PETSC_OFFLOAD_GPU;
466   }
467   if (array) *array = nullptr;
468   PetscFunctionReturn(PETSC_SUCCESS);
469 }
470 
// Get read-write access to the device array of A; bumps the object state (write access)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
478 
// Get read-only access to the device array of A; does not bump the object state. The
// const_cast only adapts to the private helper's non-const signature -- read access
// never writes through the pointer.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
486 
// Get write-only access to the device array of A; bumps the object state
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMGetArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
494 
// Restore the array obtained with MatDenseCUPMGetArray(); bumps the object state, marks
// the device copy up-to-date, and nulls *array
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArray(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
502 
// Restore the array obtained with MatDenseCUPMGetArrayRead(); nulls *array without
// bumping the object state (read access)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayRead(Mat A, const PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_READ>(A, const_cast<PetscScalar **>(array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
510 
// Restore the array obtained with MatDenseCUPMGetArrayWrite(); bumps the object state,
// marks the device copy up-to-date, and nulls *array
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMRestoreArrayWrite(Mat A, PetscScalar **array) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, PETSC_MEMORY_ACCESS_WRITE>(A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
518 
// Substitute the user-provided device `array` for the matrix' array via the composed
// PlaceArray method; bumps the object state and marks the device copy up-to-date.
// Presumably paired with MatDenseCUPMResetArray(), mirroring MatDensePlaceArray() --
// confirm against the composed implementation.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMPlaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMPlaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
529 
// Replace the matrix' device array with `array` via the composed ReplaceArray method;
// bumps the object state and marks the device copy up-to-date. NOTE(review): ownership
// semantics are those of the composed implementation (cf. MatDenseReplaceArray()) --
// confirm there.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMReplaceArray_C(), (Mat, const PetscScalar *), (A, array));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  A->offloadmask = PETSC_OFFLOAD_GPU;
  PetscFunctionReturn(PETSC_SUCCESS);
}
540 
// Undo a previous MatDenseCUPMPlaceArray() via the composed ResetArray method; bumps the
// object state
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMResetArray(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMResetArray_C(), (Mat), (A));
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(A)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
550 
// Preallocate the matrix' device storage via the composed SetPreallocation method,
// adopting `device_data` if non-NULL. A NULL `dctx` is resolved by the implementation
// (see MatDense_CUPM::SetPreallocation() above).
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDenseCUPMSetPreallocation(Mat A, PetscScalar *device_data, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscUseMethod(A, impl::MatDense_CUPM_Base<T>::MatDenseCUPMSetPreallocation_C(), (Mat, PetscDeviceContext, PetscScalar *), (A, dctx, device_data));
  PetscFunctionReturn(PETSC_SUCCESS);
}
559 
560 } // anonymous namespace
561 
562 } // namespace cupm
563 
564 } // namespace mat
565 
566 } // namespace Petsc
567 
568 #endif // __cplusplus
569 
570 #endif // PETSCMATDENSECUPMIMPL_H
571