xref: /petsc/src/mat/impls/dense/mpi/cupm/matmpidensecupm.hpp (revision 14277c9297e98638593654668ce43885242b9940)
1 #ifndef PETSCMATMPIDENSECUPM_HPP
2 #define PETSCMATMPIDENSECUPM_HPP
3 
4 #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
5 #include <../src/mat/impls/dense/mpi/mpidense.h>
6 
7 #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
8 #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>
9 
10 namespace Petsc
11 {
12 
13 namespace mat
14 {
15 
16 namespace cupm
17 {
18 
19 namespace impl
20 {
21 
22 template <device::cupm::DeviceType T>
23 class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
24 public:
25   MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);
26 
27 private:
28   PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
29   PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;
30 
31   static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;
32 
33   template <bool to_host>
34   static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;
35 
36 public:
37   PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;
38 
39   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
40   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;
41 
42   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
43   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;
44 
45   static PetscErrorCode Create(Mat) noexcept;
46 
47   static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
48   static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
49   static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;
50 
51   template <PetscMemType, PetscMemoryAccessMode>
52   static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
53   template <PetscMemType, PetscMemoryAccessMode>
54   static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
55 
56 private:
57   template <PetscMemType mtype, PetscMemoryAccessMode mode>
58   static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
59   {
60     return GetArray<mtype, mode>(m, p);
61   }
62 
63   template <PetscMemType mtype, PetscMemoryAccessMode mode>
64   static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
65   {
66     return RestoreArray<mtype, mode>(m, p);
67   }
68 
69 public:
70   template <PetscMemoryAccessMode>
71   static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
72   template <PetscMemoryAccessMode>
73   static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;
74 
75   static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
76   static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
77   static PetscErrorCode ResetArray(Mat) noexcept;
78 
79   static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
80 
81   static PetscErrorCode GetDiagonal(Mat, Vec) noexcept;
82 };
83 
84 } // namespace impl
85 
86 namespace
87 {
88 
89 // Declare this here so that the functions below can make use of it
90 template <device::cupm::DeviceType T>
91 inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
92 {
93   PetscFunctionBegin;
94   PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
95   PetscFunctionReturn(PETSC_SUCCESS);
96 }
97 
98 } // anonymous namespace
99 
100 namespace impl
101 {
102 
103 // ==========================================================================================
104 // MatDense_MPI_CUPM -- Private API
105 // ==========================================================================================
106 
107 template <device::cupm::DeviceType T>
108 inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
109 {
110   return static_cast<Mat_MPIDense *>(m->data);
111 }
112 
113 template <device::cupm::DeviceType T>
114 inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
115 {
116   return MATMPIDENSECUPM();
117 }
118 
119 // ==========================================================================================
120 
121 template <device::cupm::DeviceType T>
122 inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
123 {
124   PetscFunctionBegin;
125   if (auto &mimplA = MatIMPLCast(A)->A) {
126     PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
127     PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
128   } else {
129     PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
130   }
131   PetscFunctionReturn(PETSC_SUCCESS);
132 }
133 
134 template <device::cupm::DeviceType T>
135 template <bool to_host>
136 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
137 {
138   PetscFunctionBegin;
139   if (reuse == MAT_INITIAL_MATRIX) {
140     PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
141   } else if (reuse == MAT_REUSE_MATRIX) {
142     PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
143   }
144   {
145     const auto B    = *newmat;
146     const auto pobj = PetscObjectCast(B);
147 
148     if (to_host) {
149       PetscCall(BindToCPU(B, PETSC_TRUE));
150     } else {
151       PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
152     }
153 
154     PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
155     PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));
156 
157     // ============================================================
158     // Composed Ops
159     // ============================================================
160     MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
161     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
162     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
163     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
164     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
165     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
166     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
167     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
168     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
169     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
170     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
171     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
172     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
173     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
174     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);
175 
176     if (to_host) {
177       if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
178       B->offloadmask = PETSC_OFFLOAD_CPU;
179     } else {
180       if (auto &m_A = MatIMPLCast(B)->A) {
181         PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
182         B->offloadmask = PETSC_OFFLOAD_BOTH;
183       } else {
184         B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
185       }
186       PetscCall(BindToCPU(B, PETSC_FALSE));
187     }
188 
189     // ============================================================
190     // Function Pointer Ops
191     // ============================================================
192     MatSetOp_CUPM(to_host, B, getdiagonal, MatGetDiagonal_MPIDense, GetDiagonal);
193     MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
194   }
195   PetscFunctionReturn(PETSC_SUCCESS);
196 }
197 
198 // ==========================================================================================
199 // MatDense_MPI_CUPM -- Public API
200 // ==========================================================================================
201 
202 template <device::cupm::DeviceType T>
203 inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
204 {
205   return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
206 }
207 
208 template <device::cupm::DeviceType T>
209 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
210 {
211   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
212 }
213 
214 template <device::cupm::DeviceType T>
215 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
216 {
217   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
218 }
219 
220 template <device::cupm::DeviceType T>
221 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
222 {
223   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
224 }
225 
226 template <device::cupm::DeviceType T>
227 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
228 {
229   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
230 }
231 
232 // ==========================================================================================
233 
234 template <device::cupm::DeviceType T>
235 inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
236 {
237   PetscFunctionBegin;
238   PetscCall(MatCreate_MPIDense(A));
239   PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
240   PetscFunctionReturn(PETSC_SUCCESS);
241 }
242 
243 // ==========================================================================================
244 
245 template <device::cupm::DeviceType T>
246 inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
247 {
248   const auto mimpl = MatIMPLCast(A);
249   const auto pobj  = PetscObjectCast(A);
250 
251   PetscFunctionBegin;
252   PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
253   PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
254   if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
255   A->boundtocpu = usehost;
256   PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
257   if (!usehost) {
258     PetscBool iscupm;
259 
260     PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
261     if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
262     PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
263     if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
264   }
265 
266   MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
267   MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
268   MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
269   MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
270   MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
271   MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
272 
273   MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);
274 
275   if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
276   PetscFunctionReturn(PETSC_SUCCESS);
277 }
278 
279 template <device::cupm::DeviceType T>
280 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
281 {
282   PetscFunctionBegin;
283   PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
284   PetscFunctionReturn(PETSC_SUCCESS);
285 }
286 
287 template <device::cupm::DeviceType T>
288 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
289 {
290   PetscFunctionBegin;
291   PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
295 // ==========================================================================================
296 
297 template <device::cupm::DeviceType T>
298 template <PetscMemType, PetscMemoryAccessMode access>
299 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext dctx) noexcept
300 {
301   auto &mimplA = MatIMPLCast(A)->A;
302 
303   PetscFunctionBegin;
304   if (!mimplA) PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, nullptr, &mimplA, dctx));
305   PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimplA, array));
306   PetscFunctionReturn(PETSC_SUCCESS);
307 }
308 
309 template <device::cupm::DeviceType T>
310 template <PetscMemType, PetscMemoryAccessMode access>
311 inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
312 {
313   PetscFunctionBegin;
314   PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
315   PetscFunctionReturn(PETSC_SUCCESS);
316 }
317 
318 // ==========================================================================================
319 
320 template <device::cupm::DeviceType T>
321 template <PetscMemoryAccessMode access>
322 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
323 {
324   using namespace vec::cupm;
325 
326   const auto mimpl   = MatIMPLCast(A);
327   const auto mimpl_A = mimpl->A;
328   const auto pobj    = PetscObjectCast(A);
329   PetscInt   lda;
330 
331   PetscFunctionBegin;
332   PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
333   PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
334   mimpl->vecinuse = col + 1;
335 
336   if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));
337 
338   PetscCall(MatDenseGetLDA(mimpl_A, &lda));
339   PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
340   PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));
341 
342   if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
343   *v = mimpl->cvec;
344   PetscFunctionReturn(PETSC_SUCCESS);
345 }
346 
347 template <device::cupm::DeviceType T>
348 template <PetscMemoryAccessMode access>
349 inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
350 {
351   using namespace vec::cupm;
352 
353   const auto mimpl = MatIMPLCast(A);
354   const auto cvec  = mimpl->cvec;
355 
356   PetscFunctionBegin;
357   PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
358   PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
359   mimpl->vecinuse = 0;
360 
361   PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
362   if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
363   PetscCall(VecCUPMResetArrayAsync<T>(cvec));
364 
365   if (v) *v = nullptr;
366   PetscFunctionReturn(PETSC_SUCCESS);
367 }
368 
369 // ==========================================================================================
370 
371 template <device::cupm::DeviceType T>
372 inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
373 {
374   const auto mimpl = MatIMPLCast(A);
375 
376   PetscFunctionBegin;
377   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
378   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
379   PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
380   PetscFunctionReturn(PETSC_SUCCESS);
381 }
382 
383 template <device::cupm::DeviceType T>
384 inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
385 {
386   const auto mimpl = MatIMPLCast(A);
387 
388   PetscFunctionBegin;
389   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
390   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
391   PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
392   PetscFunctionReturn(PETSC_SUCCESS);
393 }
394 
395 template <device::cupm::DeviceType T>
396 inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
397 {
398   const auto mimpl = MatIMPLCast(A);
399 
400   PetscFunctionBegin;
401   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
402   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
403   PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
404   PetscFunctionReturn(PETSC_SUCCESS);
405 }
406 
407 // ==========================================================================================
408 
409 template <device::cupm::DeviceType T>
410 inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept
411 {
412   PetscDeviceContext dctx;
413 
414   PetscFunctionBegin;
415   PetscCall(GetHandles_(&dctx));
416   PetscCall(PetscInfo(A, "Performing Shift on backend\n"));
417   PetscCall(DiagonalUnaryTransform(A, dctx, device::cupm::functors::make_plus_equals(alpha)));
418   PetscFunctionReturn(PETSC_SUCCESS);
419 }
420 
421 template <device::cupm::DeviceType T>
422 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetDiagonal(Mat A, Vec v) noexcept
423 {
424   PetscFunctionBegin;
425   PetscCall(GetDiagonal_CUPMBase(A, v));
426   PetscFunctionReturn(PETSC_SUCCESS);
427 }
428 
429 } // namespace impl
430 
431 namespace
432 {
433 
434 template <device::cupm::DeviceType T>
435 inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
436 {
437   PetscMPIInt size;
438 
439   PetscFunctionBegin;
440   PetscValidPointer(A, 7);
441   PetscCallMPI(MPI_Comm_size(comm, &size));
442   if (size > 1) {
443     PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx));
444   } else {
445     if (n == PETSC_DECIDE) n = N;
446     if (m == PETSC_DECIDE) m = M;
447     // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
448     // the line
449     PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx));
450   }
451   PetscFunctionReturn(PETSC_SUCCESS);
452 }
453 
454 } // anonymous namespace
455 
456 } // namespace cupm
457 
458 } // namespace mat
459 
460 } // namespace Petsc
461 
462 #endif // PETSCMATMPIDENSECUPM_HPP
463