xref: /petsc/src/mat/impls/dense/mpi/cupm/matmpidensecupm.hpp (revision 4742e46b56cb5d0762110e30c569ce3737a8e22a)
1*4742e46bSJacob Faibussowitsch #ifndef PETSCMATMPIDENSECUPM_HPP
2*4742e46bSJacob Faibussowitsch #define PETSCMATMPIDENSECUPM_HPP
3*4742e46bSJacob Faibussowitsch 
4*4742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
5*4742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/mpi/mpidense.h>
6*4742e46bSJacob Faibussowitsch 
7*4742e46bSJacob Faibussowitsch #ifdef __cplusplus
8*4742e46bSJacob Faibussowitsch   #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
9*4742e46bSJacob Faibussowitsch   #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>
10*4742e46bSJacob Faibussowitsch 
11*4742e46bSJacob Faibussowitsch namespace Petsc
12*4742e46bSJacob Faibussowitsch {
13*4742e46bSJacob Faibussowitsch 
14*4742e46bSJacob Faibussowitsch namespace mat
15*4742e46bSJacob Faibussowitsch {
16*4742e46bSJacob Faibussowitsch 
17*4742e46bSJacob Faibussowitsch namespace cupm
18*4742e46bSJacob Faibussowitsch {
19*4742e46bSJacob Faibussowitsch 
20*4742e46bSJacob Faibussowitsch namespace impl
21*4742e46bSJacob Faibussowitsch {
22*4742e46bSJacob Faibussowitsch 
23*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
24*4742e46bSJacob Faibussowitsch class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
25*4742e46bSJacob Faibussowitsch public:
26*4742e46bSJacob Faibussowitsch   MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);
27*4742e46bSJacob Faibussowitsch 
28*4742e46bSJacob Faibussowitsch private:
29*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
30*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;
31*4742e46bSJacob Faibussowitsch 
32*4742e46bSJacob Faibussowitsch   static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;
33*4742e46bSJacob Faibussowitsch 
34*4742e46bSJacob Faibussowitsch   template <bool to_host>
35*4742e46bSJacob Faibussowitsch   static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;
36*4742e46bSJacob Faibussowitsch 
37*4742e46bSJacob Faibussowitsch public:
38*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;
39*4742e46bSJacob Faibussowitsch 
40*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
41*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;
42*4742e46bSJacob Faibussowitsch 
43*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
44*4742e46bSJacob Faibussowitsch   PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;
45*4742e46bSJacob Faibussowitsch 
46*4742e46bSJacob Faibussowitsch   static PetscErrorCode Create(Mat) noexcept;
47*4742e46bSJacob Faibussowitsch 
48*4742e46bSJacob Faibussowitsch   static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
49*4742e46bSJacob Faibussowitsch   static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
50*4742e46bSJacob Faibussowitsch   static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;
51*4742e46bSJacob Faibussowitsch 
52*4742e46bSJacob Faibussowitsch   template <PetscMemType, PetscMemoryAccessMode>
53*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
54*4742e46bSJacob Faibussowitsch   template <PetscMemType, PetscMemoryAccessMode>
55*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
56*4742e46bSJacob Faibussowitsch 
57*4742e46bSJacob Faibussowitsch private:
58*4742e46bSJacob Faibussowitsch   template <PetscMemType mtype, PetscMemoryAccessMode mode>
59*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
60*4742e46bSJacob Faibussowitsch   {
61*4742e46bSJacob Faibussowitsch     return GetArray<mtype, mode>(m, p);
62*4742e46bSJacob Faibussowitsch   }
63*4742e46bSJacob Faibussowitsch 
64*4742e46bSJacob Faibussowitsch   template <PetscMemType mtype, PetscMemoryAccessMode mode>
65*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
66*4742e46bSJacob Faibussowitsch   {
67*4742e46bSJacob Faibussowitsch     return RestoreArray<mtype, mode>(m, p);
68*4742e46bSJacob Faibussowitsch   }
69*4742e46bSJacob Faibussowitsch 
70*4742e46bSJacob Faibussowitsch public:
71*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode>
72*4742e46bSJacob Faibussowitsch   static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
73*4742e46bSJacob Faibussowitsch   template <PetscMemoryAccessMode>
74*4742e46bSJacob Faibussowitsch   static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;
75*4742e46bSJacob Faibussowitsch 
76*4742e46bSJacob Faibussowitsch   static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
77*4742e46bSJacob Faibussowitsch   static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
78*4742e46bSJacob Faibussowitsch   static PetscErrorCode ResetArray(Mat) noexcept;
79*4742e46bSJacob Faibussowitsch 
80*4742e46bSJacob Faibussowitsch   static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
81*4742e46bSJacob Faibussowitsch };
82*4742e46bSJacob Faibussowitsch 
83*4742e46bSJacob Faibussowitsch } // namespace impl
84*4742e46bSJacob Faibussowitsch 
85*4742e46bSJacob Faibussowitsch namespace
86*4742e46bSJacob Faibussowitsch {
87*4742e46bSJacob Faibussowitsch 
88*4742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it
89*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
90*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
91*4742e46bSJacob Faibussowitsch {
92*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
93*4742e46bSJacob Faibussowitsch   PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
94*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
95*4742e46bSJacob Faibussowitsch }
96*4742e46bSJacob Faibussowitsch 
97*4742e46bSJacob Faibussowitsch } // anonymous namespace
98*4742e46bSJacob Faibussowitsch 
99*4742e46bSJacob Faibussowitsch namespace impl
100*4742e46bSJacob Faibussowitsch {
101*4742e46bSJacob Faibussowitsch 
102*4742e46bSJacob Faibussowitsch // ==========================================================================================
103*4742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Private API
104*4742e46bSJacob Faibussowitsch // ==========================================================================================
105*4742e46bSJacob Faibussowitsch 
106*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
107*4742e46bSJacob Faibussowitsch inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
108*4742e46bSJacob Faibussowitsch {
109*4742e46bSJacob Faibussowitsch   return static_cast<Mat_MPIDense *>(m->data);
110*4742e46bSJacob Faibussowitsch }
111*4742e46bSJacob Faibussowitsch 
112*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
113*4742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
114*4742e46bSJacob Faibussowitsch {
115*4742e46bSJacob Faibussowitsch   return MATMPIDENSECUPM();
116*4742e46bSJacob Faibussowitsch }
117*4742e46bSJacob Faibussowitsch 
118*4742e46bSJacob Faibussowitsch // ==========================================================================================
119*4742e46bSJacob Faibussowitsch 
120*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
121*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
122*4742e46bSJacob Faibussowitsch {
123*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
124*4742e46bSJacob Faibussowitsch   if (auto &mimplA = MatIMPLCast(A)->A) {
125*4742e46bSJacob Faibussowitsch     PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
126*4742e46bSJacob Faibussowitsch     PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
127*4742e46bSJacob Faibussowitsch   } else {
128*4742e46bSJacob Faibussowitsch     PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
129*4742e46bSJacob Faibussowitsch   }
130*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
131*4742e46bSJacob Faibussowitsch }
132*4742e46bSJacob Faibussowitsch 
133*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
134*4742e46bSJacob Faibussowitsch template <bool to_host>
135*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
136*4742e46bSJacob Faibussowitsch {
137*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
138*4742e46bSJacob Faibussowitsch   if (reuse == MAT_INITIAL_MATRIX) {
139*4742e46bSJacob Faibussowitsch     PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
140*4742e46bSJacob Faibussowitsch   } else if (reuse == MAT_REUSE_MATRIX) {
141*4742e46bSJacob Faibussowitsch     PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
142*4742e46bSJacob Faibussowitsch   }
143*4742e46bSJacob Faibussowitsch   {
144*4742e46bSJacob Faibussowitsch     const auto B    = *newmat;
145*4742e46bSJacob Faibussowitsch     const auto pobj = PetscObjectCast(B);
146*4742e46bSJacob Faibussowitsch 
147*4742e46bSJacob Faibussowitsch     if (to_host) {
148*4742e46bSJacob Faibussowitsch       PetscCall(BindToCPU(B, PETSC_TRUE));
149*4742e46bSJacob Faibussowitsch     } else {
150*4742e46bSJacob Faibussowitsch       PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
151*4742e46bSJacob Faibussowitsch     }
152*4742e46bSJacob Faibussowitsch 
153*4742e46bSJacob Faibussowitsch     PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
154*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));
155*4742e46bSJacob Faibussowitsch 
156*4742e46bSJacob Faibussowitsch     // ============================================================
157*4742e46bSJacob Faibussowitsch     // Composed Ops
158*4742e46bSJacob Faibussowitsch     // ============================================================
159*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
160*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
161*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
162*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
163*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
164*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
165*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
166*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
167*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
168*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
169*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
170*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
171*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
172*4742e46bSJacob Faibussowitsch     MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
173*4742e46bSJacob Faibussowitsch 
174*4742e46bSJacob Faibussowitsch     if (to_host) {
175*4742e46bSJacob Faibussowitsch       if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
176*4742e46bSJacob Faibussowitsch       B->offloadmask = PETSC_OFFLOAD_CPU;
177*4742e46bSJacob Faibussowitsch     } else {
178*4742e46bSJacob Faibussowitsch       if (auto &m_A = MatIMPLCast(B)->A) {
179*4742e46bSJacob Faibussowitsch         PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
180*4742e46bSJacob Faibussowitsch         B->offloadmask = PETSC_OFFLOAD_BOTH;
181*4742e46bSJacob Faibussowitsch       } else {
182*4742e46bSJacob Faibussowitsch         B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
183*4742e46bSJacob Faibussowitsch       }
184*4742e46bSJacob Faibussowitsch       PetscCall(BindToCPU(B, PETSC_FALSE));
185*4742e46bSJacob Faibussowitsch     }
186*4742e46bSJacob Faibussowitsch 
187*4742e46bSJacob Faibussowitsch     // ============================================================
188*4742e46bSJacob Faibussowitsch     // Function Pointer Ops
189*4742e46bSJacob Faibussowitsch     // ============================================================
190*4742e46bSJacob Faibussowitsch     MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
191*4742e46bSJacob Faibussowitsch   }
192*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
193*4742e46bSJacob Faibussowitsch }
194*4742e46bSJacob Faibussowitsch 
195*4742e46bSJacob Faibussowitsch // ==========================================================================================
196*4742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Public API
197*4742e46bSJacob Faibussowitsch // ==========================================================================================
198*4742e46bSJacob Faibussowitsch 
199*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
200*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
201*4742e46bSJacob Faibussowitsch {
202*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
203*4742e46bSJacob Faibussowitsch }
204*4742e46bSJacob Faibussowitsch 
205*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
206*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
207*4742e46bSJacob Faibussowitsch {
208*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
209*4742e46bSJacob Faibussowitsch }
210*4742e46bSJacob Faibussowitsch 
211*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
212*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
213*4742e46bSJacob Faibussowitsch {
214*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
215*4742e46bSJacob Faibussowitsch }
216*4742e46bSJacob Faibussowitsch 
217*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
218*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
219*4742e46bSJacob Faibussowitsch {
220*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
221*4742e46bSJacob Faibussowitsch }
222*4742e46bSJacob Faibussowitsch 
223*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
224*4742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
225*4742e46bSJacob Faibussowitsch {
226*4742e46bSJacob Faibussowitsch   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
227*4742e46bSJacob Faibussowitsch }
228*4742e46bSJacob Faibussowitsch 
229*4742e46bSJacob Faibussowitsch // ==========================================================================================
230*4742e46bSJacob Faibussowitsch 
231*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
232*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
233*4742e46bSJacob Faibussowitsch {
234*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
235*4742e46bSJacob Faibussowitsch   PetscCall(MatCreate_MPIDense(A));
236*4742e46bSJacob Faibussowitsch   PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
237*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
238*4742e46bSJacob Faibussowitsch }
239*4742e46bSJacob Faibussowitsch 
240*4742e46bSJacob Faibussowitsch // ==========================================================================================
241*4742e46bSJacob Faibussowitsch 
242*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
243*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
244*4742e46bSJacob Faibussowitsch {
245*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
246*4742e46bSJacob Faibussowitsch   const auto pobj  = PetscObjectCast(A);
247*4742e46bSJacob Faibussowitsch 
248*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
249*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
250*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
251*4742e46bSJacob Faibussowitsch   if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
252*4742e46bSJacob Faibussowitsch   A->boundtocpu = usehost;
253*4742e46bSJacob Faibussowitsch   PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
254*4742e46bSJacob Faibussowitsch   if (!usehost) {
255*4742e46bSJacob Faibussowitsch     PetscBool iscupm;
256*4742e46bSJacob Faibussowitsch 
257*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
258*4742e46bSJacob Faibussowitsch     if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
259*4742e46bSJacob Faibussowitsch     PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
260*4742e46bSJacob Faibussowitsch     if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
261*4742e46bSJacob Faibussowitsch   }
262*4742e46bSJacob Faibussowitsch 
263*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
264*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
265*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
266*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
267*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
268*4742e46bSJacob Faibussowitsch   MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
269*4742e46bSJacob Faibussowitsch 
270*4742e46bSJacob Faibussowitsch   MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);
271*4742e46bSJacob Faibussowitsch 
272*4742e46bSJacob Faibussowitsch   if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
273*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
274*4742e46bSJacob Faibussowitsch }
275*4742e46bSJacob Faibussowitsch 
276*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
277*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
278*4742e46bSJacob Faibussowitsch {
279*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
280*4742e46bSJacob Faibussowitsch   PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
281*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
282*4742e46bSJacob Faibussowitsch }
283*4742e46bSJacob Faibussowitsch 
284*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
285*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
286*4742e46bSJacob Faibussowitsch {
287*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
288*4742e46bSJacob Faibussowitsch   PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
289*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
290*4742e46bSJacob Faibussowitsch }
291*4742e46bSJacob Faibussowitsch 
292*4742e46bSJacob Faibussowitsch // ==========================================================================================
293*4742e46bSJacob Faibussowitsch 
294*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
295*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access>
296*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
297*4742e46bSJacob Faibussowitsch {
298*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
299*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array));
300*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
301*4742e46bSJacob Faibussowitsch }
302*4742e46bSJacob Faibussowitsch 
303*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
304*4742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access>
305*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
306*4742e46bSJacob Faibussowitsch {
307*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
308*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
309*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
310*4742e46bSJacob Faibussowitsch }
311*4742e46bSJacob Faibussowitsch 
312*4742e46bSJacob Faibussowitsch // ==========================================================================================
313*4742e46bSJacob Faibussowitsch 
314*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
315*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
316*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
317*4742e46bSJacob Faibussowitsch {
318*4742e46bSJacob Faibussowitsch   using namespace vec::cupm;
319*4742e46bSJacob Faibussowitsch 
320*4742e46bSJacob Faibussowitsch   const auto mimpl   = MatIMPLCast(A);
321*4742e46bSJacob Faibussowitsch   const auto mimpl_A = mimpl->A;
322*4742e46bSJacob Faibussowitsch   const auto pobj    = PetscObjectCast(A);
323*4742e46bSJacob Faibussowitsch   auto      &cvec    = mimpl->cvec;
324*4742e46bSJacob Faibussowitsch   PetscInt   lda;
325*4742e46bSJacob Faibussowitsch 
326*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
327*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
328*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
329*4742e46bSJacob Faibussowitsch   mimpl->vecinuse = col + 1;
330*4742e46bSJacob Faibussowitsch 
331*4742e46bSJacob Faibussowitsch   if (!cvec) PetscCall(VecCreateMPICUPMWithArray<T>(PetscObjectComm(pobj), A->rmap->bs, A->rmap->n, A->rmap->N, nullptr, &cvec));
332*4742e46bSJacob Faibussowitsch 
333*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseGetLDA(mimpl_A, &lda));
334*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
335*4742e46bSJacob Faibussowitsch   PetscCall(VecCUPMPlaceArrayAsync<T>(cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));
336*4742e46bSJacob Faibussowitsch 
337*4742e46bSJacob Faibussowitsch   if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(cvec));
338*4742e46bSJacob Faibussowitsch   *v = cvec;
339*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
340*4742e46bSJacob Faibussowitsch }
341*4742e46bSJacob Faibussowitsch 
342*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
343*4742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
344*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
345*4742e46bSJacob Faibussowitsch {
346*4742e46bSJacob Faibussowitsch   using namespace vec::cupm;
347*4742e46bSJacob Faibussowitsch 
348*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
349*4742e46bSJacob Faibussowitsch   const auto cvec  = mimpl->cvec;
350*4742e46bSJacob Faibussowitsch 
351*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
352*4742e46bSJacob Faibussowitsch   PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
353*4742e46bSJacob Faibussowitsch   PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
354*4742e46bSJacob Faibussowitsch   mimpl->vecinuse = 0;
355*4742e46bSJacob Faibussowitsch 
356*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
357*4742e46bSJacob Faibussowitsch   if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
358*4742e46bSJacob Faibussowitsch   PetscCall(VecCUPMResetArrayAsync<T>(cvec));
359*4742e46bSJacob Faibussowitsch 
360*4742e46bSJacob Faibussowitsch   if (v) *v = nullptr;
361*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
362*4742e46bSJacob Faibussowitsch }
363*4742e46bSJacob Faibussowitsch 
364*4742e46bSJacob Faibussowitsch // ==========================================================================================
365*4742e46bSJacob Faibussowitsch 
366*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
367*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
368*4742e46bSJacob Faibussowitsch {
369*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
370*4742e46bSJacob Faibussowitsch 
371*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
372*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
373*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
374*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
375*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
376*4742e46bSJacob Faibussowitsch }
377*4742e46bSJacob Faibussowitsch 
378*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
379*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
380*4742e46bSJacob Faibussowitsch {
381*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
382*4742e46bSJacob Faibussowitsch 
383*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
384*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
385*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
386*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
387*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
388*4742e46bSJacob Faibussowitsch }
389*4742e46bSJacob Faibussowitsch 
390*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
391*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
392*4742e46bSJacob Faibussowitsch {
393*4742e46bSJacob Faibussowitsch   const auto mimpl = MatIMPLCast(A);
394*4742e46bSJacob Faibussowitsch 
395*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
396*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
397*4742e46bSJacob Faibussowitsch   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
398*4742e46bSJacob Faibussowitsch   PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
399*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
400*4742e46bSJacob Faibussowitsch }
401*4742e46bSJacob Faibussowitsch 
402*4742e46bSJacob Faibussowitsch // ==========================================================================================
403*4742e46bSJacob Faibussowitsch 
404*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
405*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept
406*4742e46bSJacob Faibussowitsch {
407*4742e46bSJacob Faibussowitsch   PetscDeviceContext dctx;
408*4742e46bSJacob Faibussowitsch 
409*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
410*4742e46bSJacob Faibussowitsch   PetscCall(GetHandles_(&dctx));
411*4742e46bSJacob Faibussowitsch   PetscCall(PetscInfo(A, "Performing Shift on backend\n"));
412*4742e46bSJacob Faibussowitsch   PetscCall(PointwiseUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha)));
413*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
414*4742e46bSJacob Faibussowitsch }
415*4742e46bSJacob Faibussowitsch 
416*4742e46bSJacob Faibussowitsch } // namespace impl
417*4742e46bSJacob Faibussowitsch 
418*4742e46bSJacob Faibussowitsch namespace
419*4742e46bSJacob Faibussowitsch {
420*4742e46bSJacob Faibussowitsch 
421*4742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
422*4742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
423*4742e46bSJacob Faibussowitsch {
424*4742e46bSJacob Faibussowitsch   PetscMPIInt size;
425*4742e46bSJacob Faibussowitsch 
426*4742e46bSJacob Faibussowitsch   PetscFunctionBegin;
427*4742e46bSJacob Faibussowitsch   PetscValidPointer(A, 7);
428*4742e46bSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(comm, &size));
429*4742e46bSJacob Faibussowitsch   if (size > 1) {
430*4742e46bSJacob Faibussowitsch     PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx));
431*4742e46bSJacob Faibussowitsch   } else {
432*4742e46bSJacob Faibussowitsch     if (n == PETSC_DECIDE) n = N;
433*4742e46bSJacob Faibussowitsch     if (m == PETSC_DECIDE) m = M;
434*4742e46bSJacob Faibussowitsch     // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
435*4742e46bSJacob Faibussowitsch     // the line
436*4742e46bSJacob Faibussowitsch     PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx));
437*4742e46bSJacob Faibussowitsch   }
438*4742e46bSJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
439*4742e46bSJacob Faibussowitsch }
440*4742e46bSJacob Faibussowitsch 
441*4742e46bSJacob Faibussowitsch } // anonymous namespace
442*4742e46bSJacob Faibussowitsch 
443*4742e46bSJacob Faibussowitsch } // namespace cupm
444*4742e46bSJacob Faibussowitsch 
445*4742e46bSJacob Faibussowitsch } // namespace mat
446*4742e46bSJacob Faibussowitsch 
447*4742e46bSJacob Faibussowitsch } // namespace Petsc
448*4742e46bSJacob Faibussowitsch 
449*4742e46bSJacob Faibussowitsch #endif // __cplusplus
450*4742e46bSJacob Faibussowitsch 
451*4742e46bSJacob Faibussowitsch #endif // PETSCMATMPIDENSECUPM_HPP
452