xref: /petsc/src/mat/impls/dense/mpi/cupm/matmpidensecupm.hpp (revision 80aecfd0aa0b9da7f91c87289ecefccb7fa72419)
1 #ifndef PETSCMATMPIDENSECUPM_HPP
2 #define PETSCMATMPIDENSECUPM_HPP
3 
4 #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
5 #include <../src/mat/impls/dense/mpi/mpidense.h>
6 
7 #ifdef __cplusplus
8   #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
9   #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>
10 
11 namespace Petsc
12 {
13 
14 namespace mat
15 {
16 
17 namespace cupm
18 {
19 
20 namespace impl
21 {
22 
// Device (CUDA or HIP, selected by T) implementation of the MPI-parallel dense matrix
// type. Inherits the machinery shared with the sequential implementation from
// MatDense_CUPM via CRTP.
template <device::cupm::DeviceType T>
class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
public:
  // Injects the common type aliases and static interface shared by the seq/mpi CUPM
  // dense matrix classes
  MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);

private:
  // Cast Mat::data to the MPI dense implementation struct
  PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
  // Type name of this device matrix implementation (MATMPIDENSECUPM)
  PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;

  // Allocate (or re-type) the local per-rank block, optionally adopting device_array
  static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  // Common driver for both conversion directions; to_host selects device->host (true)
  // or host->device (false)
  template <bool to_host>
  static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;

public:
  // Composed-function names, spelled per-device ("...cuda..." vs "...hip...")
  PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;

  static PetscErrorCode Create(Mat) noexcept;

  static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
  static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
  static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;

  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;

private:
  // Two-argument trampolines (no PetscDeviceContext) with the signature used when these
  // are composed onto the object in Convert_Dispatch_() below
  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return GetArray<mtype, mode>(m, p);
  }

  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return RestoreArray<mtype, mode>(m, p);
  }

public:
  // Column check-out/check-in: expose column `col` of the matrix as a Vec backed by the
  // device array (see GetColumnVec/RestoreColumnVec definitions)
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;

  static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ResetArray(Mat) noexcept;

  static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
};
82 
83 } // namespace impl
84 
85 namespace
86 {
87 
88 // Declare this here so that the functions below can make use of it
// Declare this here so that the functions below can make use of it
//
// Convenience wrapper creating a MATMPIDENSECUPM matrix with local size m x n and
// global size M x N; forwards everything (including the optional user-provided array
// and the preallocation flag) to the implementation class.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
  PetscFunctionReturn(PETSC_SUCCESS);
}
96 
97 } // anonymous namespace
98 
99 namespace impl
100 {
101 
102 // ==========================================================================================
103 // MatDense_MPI_CUPM -- Private API
104 // ==========================================================================================
105 
// Reinterpret the opaque Mat::data pointer as the MPI dense implementation struct
template <device::cupm::DeviceType T>
inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
{
  return static_cast<Mat_MPIDense *>(m->data);
}
111 
// The device matrix type name implemented by this class
template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
{
  return MATMPIDENSECUPM();
}
117 
118 // ==========================================================================================
119 
// Preallocate the local per-rank block of A on the device. If device_array is non-null
// it is handed to the sequential implementation (presumably adopted rather than copied
// -- see MatDense_Seq_CUPM<T>::SetPreallocation / MatCreateSeqDenseCUPM).
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  // Bind a reference to the local block so the else-branch can create it in place
  if (auto &mimplA = MatIMPLCast(A)->A) {
    // Local block already exists: re-type it to the device type, then preallocate
    PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
    PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
  } else {
    // No local block yet: create it with local rows x global columns
    PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
132 
// Convert between MATMPIDENSE and MATMPIDENSECUPM. to_host == true converts
// device -> host, false converts host -> device. Handles MAT_INITIAL_MATRIX (duplicate)
// and MAT_REUSE_MATRIX (copy); for any other reuse value (i.e. MAT_INPLACE_MATRIX)
// *newmat is assumed to already alias M and is modified in place.
template <device::cupm::DeviceType T>
template <bool to_host>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
  }
  {
    const auto B    = *newmat;
    const auto pobj = PetscObjectCast(B);

    if (to_host) {
      PetscCall(BindToCPU(B, PETSC_TRUE));
    } else {
      // Going to the device: make sure the device backend is initialized first
      PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
    }

    // The default vector type and the object's type name must track the new matrix type
    PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
    PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));

    // ============================================================
    // Composed Ops
    // ============================================================
    // MatComposeOp_CUPM(to_host, obj, name, host_fn, device_fn): composes host_fn
    // (nullptr here, i.e. removes the function) when converting to host, device_fn when
    // converting to device
    MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);

    if (to_host) {
      // Convert the local per-rank block back to host; data is now valid on the CPU
      if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
      B->offloadmask = PETSC_OFFLOAD_CPU;
    } else {
      if (auto &m_A = MatIMPLCast(B)->A) {
        PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
        B->offloadmask = PETSC_OFFLOAD_BOTH;
      } else {
        // No local block exists yet, so nothing has been allocated on either side
        B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
      PetscCall(BindToCPU(B, PETSC_FALSE));
    }

    // ============================================================
    // Function Pointer Ops
    // ============================================================
    MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
195 
196 // ==========================================================================================
197 // MatDense_MPI_CUPM -- Public API
198 // ==========================================================================================
199 
200 template <device::cupm::DeviceType T>
201 inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
202 {
203   return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
204 }
205 
206 template <device::cupm::DeviceType T>
207 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
208 {
209   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
210 }
211 
212 template <device::cupm::DeviceType T>
213 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
214 {
215   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
216 }
217 
218 template <device::cupm::DeviceType T>
219 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
220 {
221   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
222 }
223 
224 template <device::cupm::DeviceType T>
225 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
226 {
227   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
228 }
229 
230 // ==========================================================================================
231 
// MATMPIDENSECUPM constructor: build a plain host MPIDense matrix, then convert it to
// the device type in place
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatCreate_MPIDense(A));
  PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
240 
241 // ==========================================================================================
242 
// Bind the matrix to the CPU (usehost == PETSC_TRUE) or release it to run on the
// device. Swaps the column-vector and shift operations between the host MPIDense
// implementations and the device ones, and propagates the binding to the cached
// sub-objects.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
{
  const auto mimpl = MatIMPLCast(A);
  const auto pobj  = PetscObjectCast(A);

  PetscFunctionBegin;
  // Cannot rebind while a column vector or submatrix is checked out
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  // Propagate the binding to the local per-rank block
  if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
  A->boundtocpu = usehost;
  // Default random type must match where the data lives
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
  if (!usehost) {
    PetscBool iscupm;

    // The cached column vector and column-range matrix must be of the device type;
    // destroy stale host-typed ones so they are lazily recreated with the right type
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
    if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
    if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
  }

  // Host implementation when bound to CPU, device implementation otherwise
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);

  MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);

  // Propagate the binding to the cached column-range matrix as well
  if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
  PetscFunctionReturn(PETSC_SUCCESS);
}
276 
// MATMPIDENSECUPM -> MATMPIDENSE (device to host)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
284 
// MATMPIDENSE -> MATMPIDENSECUPM (host to device)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
292 
293 // ==========================================================================================
294 
// Get a pointer to the matrix entries by delegating to the local per-rank block; the
// memtype template parameter and the device context are unused here
template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
303 
// Return a pointer obtained from GetArray() by delegating to the local per-rank block;
// the memtype template parameter and the device context are unused here
template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
312 
313 // ==========================================================================================
314 
// Check out column `col` of the matrix as the cached vector mimpl->cvec, backed
// directly by the device array (no copy). Must be paired with RestoreColumnVec() with
// the same access mode before any other check-out.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl   = MatIMPLCast(A);
  const auto mimpl_A = mimpl->A;
  const auto pobj    = PetscObjectCast(A);
  PetscInt   lda;

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  // Record the checked-out column as col+1 so that column 0 still reads as "in use"
  mimpl->vecinuse = col + 1;

  // Lazily create the reusable column vector
  if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));

  // Acquire the device array and point cvec at the start of the requested column
  // (columns are lda apart; size_t arithmetic avoids PetscInt overflow in the offset)
  PetscCall(MatDenseGetLDA(mimpl_A, &lda));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));

  // Read-only access: lock the vector so callers cannot modify it
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
  *v = mimpl->cvec;
  PetscFunctionReturn(PETSC_SUCCESS);
}
341 
// Check a column vector obtained from GetColumnVec() back in: restore the device
// array, drop the read lock (if any), detach the placed pointer, and clear the
// caller's reference. The column index parameter is unused.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl = MatIMPLCast(A);
  const auto cvec  = mimpl->cvec;

  PetscFunctionBegin;
  PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
  PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
  mimpl->vecinuse = 0;

  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
  PetscCall(VecCUPMResetArrayAsync<T>(cvec));

  if (v) *v = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}
363 
364 // ==========================================================================================
365 
// Temporarily substitute the matrix's device array with a caller-provided one (undone
// by ResetArray()); delegates to the local per-rank block
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  // Cannot swap the backing array while a column vector or submatrix is checked out
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
377 
// Permanently replace the matrix's device array with a caller-provided one; delegates
// to the local per-rank block
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  // Cannot swap the backing array while a column vector or submatrix is checked out
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
389 
// Undo a previous PlaceArray(), restoring the matrix's original device array; delegates
// to the local per-rank block
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  // Cannot swap the backing array while a column vector or submatrix is checked out
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
401 
402 // ==========================================================================================
403 
// A += alpha*I, performed on the device: applies "+= alpha" to the diagonal entries
// whose global row index lies in this rank's row range [rstart, rend)
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(PetscInfo(A, "Performing Shift on backend\n"));
  PetscCall(DiagonalUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
415 
416 } // namespace impl
417 
418 namespace
419 {
420 
// Create a dense device matrix on comm, dispatching on the communicator size: the MPI
// implementation for size > 1, the sequential one otherwise. The first size pair is
// forwarded as the local dimensions and the second as the global dimensions of the
// MPI creation routine.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscValidPointer(A, 7); // A is the 7th argument
  PetscCallMPI(MPI_Comm_size(comm, &size));
  if (size > 1) {
    PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx));
  } else {
    // Sequential case: resolve PETSC_DECIDE local sizes to the global ones, since the
    // seq constructor takes a single size pair
    if (n == PETSC_DECIDE) n = N;
    if (m == PETSC_DECIDE) m = M;
    // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
    // the line
    PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
440 
441 } // anonymous namespace
442 
443 } // namespace cupm
444 
445 } // namespace mat
446 
447 } // namespace Petsc
448 
449 #endif // __cplusplus
450 
451 #endif // PETSCMATMPIDENSECUPM_HPP
452